[desktop] Enable ONNX Face ML (#1413)
Indexing + PUT embeddings works now, but most other things are still WIP.
@@ -1,8 +1,8 @@
# Dependencies

* [Electron](#electron)
* [Dev dependencies](#dev)
* [Functionality](#functionality)
- [Electron](#electron)
- [Dev dependencies](#dev)
- [Functionality](#functionality)

## Electron

@@ -114,8 +114,8 @@ available on the host machine, and is not bundled with our app.
AI/ML runtime. It powers both natural language searches (using CLIP) and face
detection (using YOLO).

[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
JPEG data into raw RGB bytes before passing it to ONNX.
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
data into raw RGB bytes before passing it to ONNX.

html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.
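To make the hand-off described above concrete, here is a minimal, hedged sketch of going from a JPEG file to an ONNX tensor with jpeg-js (illustrative only: the helper name is made up, and the resizing and normalization that the actual CLIP and YOLO pipelines perform before inference are elided):

```ts
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import * as ort from "onnxruntime-node";

// Decode a JPEG file into raw RGB bytes and wrap them in an ONNX tensor.
// Hypothetical helper for illustration; the real code also resizes and
// normalizes the pixels to whatever the model expects.
const jpegFileToTensor = async (jpegFilePath: string) => {
    const jpegData = await fs.readFile(jpegFilePath);
    // useTArray returns a typed array; formatAsRGBA: false drops the alpha
    // channel, leaving 3 bytes (R, G, B) per pixel.
    const { width, height, data } = jpeg.decode(jpegData, {
        useTArray: true,
        formatAsRGBA: false,
    });
    const pixels = Float32Array.from(data, (v) => v / 255);
    return new ort.Tensor("float32", pixels, [1, height, width, 3]);
};
```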
@@ -36,13 +36,14 @@ import {
updateAndRestart,
updateOnNextRestart,
} from "./services/app-update";
import { clipImageEmbedding, clipTextEmbedding } from "./services/clip";
import { runFFmpegCmd } from "./services/ffmpeg";
import { getDirFiles } from "./services/fs";
import {
convertToJPEG,
generateImageThumbnail,
} from "./services/imageProcessor";
import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
import { detectFaces, faceEmbedding } from "./services/ml-face";
import {
clearStores,
encryptionKey,

@@ -146,6 +147,14 @@ export const attachIPCHandlers = () => {
clipTextEmbedding(text),
);

ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);

ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>
faceEmbedding(input),
);

// - File selection

ipcMain.handle("selectDirectory", () => selectDirectory());

@@ -1,19 +1,13 @@
/**
* @file Compute CLIP embeddings
* @file Compute CLIP embeddings for images and text.
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details. This
* file implements the Node.js implementation of the actual embedding
* computation. By doing it in the Node.js layer, we can use the binary ONNX
* runtimes which are 10-20x faster than the WASM based web ones.
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*
* The embeddings are computed using ONNX runtime. The model itself is not
* shipped with the app but is downloaded on demand.
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import { CustomErrors } from "../../types/ipc";

@@ -21,6 +15,12 @@ import { writeStream } from "../fs";
import log from "../log";
import { generateTempFilePath } from "../temp";
import { deleteTempFile } from "./ffmpeg";
import {
createInferenceSession,
downloadModel,
modelPathDownloadingIfNeeded,
modelSavePath,
} from "./ml";

const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB

@@ -28,55 +28,20 @@ const textModelByteSize = 64173509; // 61.2 MB
const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB

/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);

const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading CLIP model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded CLIP model ${name}`);
};

let activeImageModelDownload: Promise<void> | undefined;
let activeImageModelDownload: Promise<string> | undefined;

const imageModelPathDownloadingIfNeeded = async () => {
try {
const modelPath = modelSavePath(imageModelName);
if (activeImageModelDownload) {
log.info("Waiting for CLIP image model download to finish");
await activeImageModelDownload;
} else {
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== imageModelByteSize) {
log.error(
`CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
);
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
}
}
activeImageModelDownload = modelPathDownloadingIfNeeded(
imageModelName,
imageModelByteSize,
);
return await activeImageModelDownload;
}
return modelPath;
} finally {
activeImageModelDownload = undefined;
}

@@ -84,6 +49,8 @@ const imageModelPathDownloadingIfNeeded = async () => {

let textModelDownloadInProgress = false;

/* TODO(MR): use the generic method. Then we can remove the exports for the
internal details functions that we use here */
const textModelPathDownloadingIfNeeded = async () => {
if (textModelDownloadInProgress)
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);

@@ -123,13 +90,6 @@ const textModelPathDownloadingIfNeeded = async () => {
return modelPath;
};

const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
};

let imageSessionPromise: Promise<any> | undefined;

const onnxImageSession = async () => {

@@ -174,7 +134,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const imageEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(imageEmbedding);

@@ -281,7 +241,7 @@ export const clipTextEmbedding = async (text: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data;
return normalizeEmbedding(textEmbedding);

desktop/src/main/services/ml-face.ts (new file, 108 lines)

@@ -0,0 +1,108 @@
/**
* @file Various face recognition related tasks.
*
* - Face detection with the YOLO model.
* - Face embedding with the MobileFaceNet model.
*
* The runtime used is ONNX.
*/
import * as ort from "onnxruntime-node";
import log from "../log";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";

const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
const faceDetectionModelByteSize = 30762872; // 29.3 MB

const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
const faceEmbeddingModelByteSize = 5286998; // 5 MB

let activeFaceDetectionModelDownload: Promise<string> | undefined;

const faceDetectionModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceDetectionModelDownload) {
log.info("Waiting for face detection model download to finish");
await activeFaceDetectionModelDownload;
} else {
activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
faceDetectionModelName,
faceDetectionModelByteSize,
);
return await activeFaceDetectionModelDownload;
}
} finally {
activeFaceDetectionModelDownload = undefined;
}
};

let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;

const faceDetectionSession = async () => {
if (!_faceDetectionSession) {
_faceDetectionSession =
faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceDetectionSession;
};

let activeFaceEmbeddingModelDownload: Promise<string> | undefined;

const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceEmbeddingModelDownload) {
log.info("Waiting for face embedding model download to finish");
await activeFaceEmbeddingModelDownload;
} else {
activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
faceEmbeddingModelName,
faceEmbeddingModelByteSize,
);
return await activeFaceEmbeddingModelDownload;
}
} finally {
activeFaceEmbeddingModelDownload = undefined;
}
};

let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;

const faceEmbeddingSession = async () => {
if (!_faceEmbeddingSession) {
_faceEmbeddingSession =
faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceEmbeddingSession;
};

export const detectFaces = async (input: Float32Array) => {
const session = await faceDetectionSession();
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`);
return results["output"].data;
};

export const faceEmbedding = async (input: Float32Array) => {
// Each face is a 112 x 112 image with 3 channels.
const mobileFaceNetFaceSize = 112;
// Shorter alias.
const z = mobileFaceNetFaceSize;
// Number of faces in the batch.
const n = Math.round(input.length / (z * z * 3));
const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

const session = await faceEmbeddingSession();
const t = Date.now();
const feeds = { img_inputs: inputTensor };
const results = await session.run(feeds);
log.debug(() => `onnx/mobilefacenet face embedding took ${Date.now() - t} ms`);
// TODO: What's with this type? It works in practice, but double check.
return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
};
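Regarding the `cpuData` workaround and the TODO above: a minimal sketch of the more conventional way to read a session output, assuming the public onnxruntime-node `Tensor.data` accessor (a suggestion, not the code this commit ships):

```ts
import * as ort from "onnxruntime-node";

// Read a named float32 output from the map returned by session.run().
// `Tensor.data` returns the backing typed array, so no untyped property
// access should be needed.
const readFloat32Output = (
    results: ort.InferenceSession.OnnxValueMapType,
    name: string,
): Float32Array => results[name].data as Float32Array;
```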
desktop/src/main/services/ml.ts (new file, 79 lines)

@@ -0,0 +1,79 @@
/**
* @file AI/ML related functionality.
*
* @see also `ml-clip.ts`, `ml-face.ts`.
*
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
* for various tasks are not shipped with the app but are downloaded on demand.
*
* The primary reason for doing these tasks in the Node.js layer is so that we
* can use the binary ONNX runtime which is 10-20x faster than the WASM based
* web one.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import { writeStream } from "../fs";
import log from "../log";

/**
* Download the model named {@link modelName} if we don't already have it.
*
* Also verify that the size of the model we get matches {@link expectedByteSize} (if
* not, redownload it).
*
* @returns the path to the model on the local machine.
*/
export const modelPathDownloadingIfNeeded = async (
modelName: string,
expectedByteSize: number,
) => {
const modelPath = modelSavePath(modelName);

if (!existsSync(modelPath)) {
log.info(`Model ${modelName} not found, downloading`);
await downloadModel(modelPath, modelName);
} else {
const size = (await fs.stat(modelPath)).size;
if (size !== expectedByteSize) {
log.error(
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
);
await downloadModel(modelPath, modelName);
}
}

return modelPath;
};

/** Return the path where the given {@link modelName} is meant to be saved */
export const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);

export const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading ML model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded ML model ${name}`);
};

/**
* Create an ONNX {@link InferenceSession} with some defaults.
*/
export const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1
intraOpNumThreads: 1,
// Be more conservative with RAM usage
enableCpuMemArena: false,
});
};
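A minimal sketch of how a caller is expected to combine these two helpers, mirroring the lazily-cached session pattern used in `ml-face.ts` above (the model name and byte size here are placeholders, not real models):

```ts
import * as ort from "onnxruntime-node";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";

// Placeholder values, for illustration only.
const someModelName = "some-model.onnx";
const someModelByteSize = 123456;

let _session: Promise<ort.InferenceSession> | undefined;

// Download the model on first use (re-downloading if the size mismatches),
// create the session once, and reuse it for subsequent calls.
const someModelSession = () => {
    if (!_session) {
        _session = modelPathDownloadingIfNeeded(
            someModelName,
            someModelByteSize,
        ).then((modelPath) => createInferenceSession(modelPath));
    }
    return _session;
};
```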
@@ -143,6 +143,12 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
ipcRenderer.invoke("clipTextEmbedding", text);

const detectFaces = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("detectFaces", input);

const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("faceEmbedding", input);

// - File selection

// TODO: Deprecated - use dialogs on the renderer process itself

@@ -322,6 +328,8 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
clipImageEmbedding,
clipTextEmbedding,
detectFaces,
faceEmbedding,

// - File selection
selectDirectory,

@@ -53,15 +53,4 @@ export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};

export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;

export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;

export const MAX_ML_SYNC_ERROR_COUNT = 1;

export const TEXT_DETECTION_TIMEOUT_MS = [10000, 30000, 60000, 120000, 240000];

export const MOBILEFACENET_FACE_SIZE = 112;
export const MOBILEFACENET_EMBEDDING_SIZE = 192;

export const BATCHES_BEFORE_SYNCING_INDEX = 5;

@@ -1,4 +1,6 @@
import { inWorker } from "@/next/env";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error";
import HTTPService from "@ente/shared/network/HTTPService";

@@ -262,7 +264,9 @@ export const putEmbedding = async (
putEmbeddingReq: PutEmbeddingRequest,
): Promise<EncryptedEmbedding> => {
try {
const token = getToken();
const token = inWorker()
? await workerBridge.getAuthToken()
: getToken();
if (!token) {
log.info("putEmbedding failed: token not found");
throw Error(CustomError.TOKEN_MISSING);

@@ -55,7 +55,7 @@ class FaceService {
await syncContext.faceDetectionService.detectFaces(imageBitmap);
console.timeEnd(timerId);
console.log("faceDetections: ", faceDetections?.length);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));

// TODO: reenable faces filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
return {

@@ -150,7 +150,7 @@ class FaceService {

imageBitmap.close();
log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
// log.info('4 TF Memory stats: ',JSON.stringify(tf.memory()));

return faceImages;
}

@@ -187,7 +187,6 @@ class FaceService {
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));

log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
// log.info('5 TF Memory stats: ',JSON.stringify(tf.memory()));
}

async syncFileFaceMakeRelativeDetections(

@@ -226,11 +225,21 @@ class FaceService {
face.detection,
syncContext.config.faceCrop,
);
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
try {
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
} catch (e) {
// TODO(MR): Temporarily ignoring errors about failing cache puts
// when using a custom scheme in Electron. Needs an alternative
// approach, perhaps OPFS.
console.error(
"Ignoring error when caching face crop, the face crop will not be available",
e,
);
}
const blob = await imageBitmapToBlob(faceCrop.image);
faceCrop.image.close();
return blob;

@@ -1,10 +1,10 @@
import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
import {
BlurDetectionMethod,
BlurDetectionService,
Versioned,
} from "types/machineLearning";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
import { mobileFaceNetFaceSize } from "./mobileFaceNetEmbeddingService";

class LaplacianBlurDetectionService implements BlurDetectionService {
public method: Versioned<BlurDetectionMethod>;

@@ -19,7 +19,7 @@ class LaplacianBlurDetectionService implements BlurDetectionService {
public detectBlur(alignedFaces: Float32Array): number[] {
const numFaces = Math.round(
alignedFaces.length /
(MOBILEFACENET_FACE_SIZE * MOBILEFACENET_FACE_SIZE * 3),
(mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3),
);
const blurValues: number[] = [];
for (let i = 0; i < numFaces; i++) {

@@ -203,9 +203,6 @@ export class LocalMLSyncContext implements MLSyncContext {
}

public async dispose() {
// await this.faceDetectionService.dispose();
// await this.faceEmbeddingService.dispose();

this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();

@@ -186,8 +186,7 @@ class MLWorkManager {
return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
});

// @ts-expect-error "TODO: Fix ML related type errors"
if ("message" in result) {
if (result instanceof Error) {
// TODO: redirect/refresh to gallery in case of session_expired
// may not be required as uploader should anyways take care of this
console.error("Error while syncing local file: ", result);

@@ -1,7 +1,4 @@
import {
MOBILEFACENET_EMBEDDING_SIZE,
MOBILEFACENET_FACE_SIZE,
} from "constants/mlConfig";
import { workerBridge } from "@/next/worker/worker-bridge";
import {
FaceEmbedding,
FaceEmbeddingMethod,

@@ -9,17 +6,9 @@ import {
Versioned,
} from "types/machineLearning";

// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
export const mobileFaceNetFaceSize = 112;

// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceEmbeddingMethod>;
public faceSize: number;

@@ -28,80 +17,25 @@ class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
value: "MobileFaceNet",
version: 2,
};
this.faceSize = MOBILEFACENET_FACE_SIZE;
// TODO: set timeout
}

private async initOnnx() {
console.log("start ort mobilefacenet");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/mobilefacenet/mobilefacenet_opset15.onnx",
);
const faceBatchSize = 1;
const data = new Float32Array(
faceBatchSize * 3 * this.faceSize * this.faceSize,
);
const inputTensor = new ort.Tensor("float32", data, [
faceBatchSize,
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end mobilefacenet");
}

private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
this.faceSize = mobileFaceNetFaceSize;
}

public async getFaceEmbeddings(
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> {
const inputTensor = new ort.Tensor("float32", faceData, [
Math.round(faceData.length / (this.faceSize * this.faceSize * 3)),
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["img_inputs"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
// TODO(MR): onnx-yolo
// const runout: ort.InferenceSession.OnnxValueMapType =
const runout: any = await inferenceSession.run(feeds);
// const test = runout.embeddings;
// const test2 = test.cpuData;
const outputData = runout.embeddings["cpuData"] as Float32Array;
const outputData = await workerBridge.faceEmbedding(faceData);

const embeddingSize = 192;
const embeddings = new Array<FaceEmbedding>(
outputData.length / MOBILEFACENET_EMBEDDING_SIZE,
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(
i * MOBILEFACENET_EMBEDDING_SIZE,
(i + 1) * MOBILEFACENET_EMBEDDING_SIZE,
),
outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
);
}
return embeddings;
}

public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}

export default new MobileFaceNetEmbeddingService();

@@ -1,4 +1,5 @@
import { MAX_FACE_DISTANCE_PERCENT } from "constants/mlConfig";
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import {
Matrix,
applyToPoint,

@@ -21,17 +22,7 @@ import {
import { newBox } from "utils/machineLearning";
import { Box, Point } from "../../../thirdparty/face-api/classes";

// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};

// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class YoloFaceDetectionService implements FaceDetectionService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceDetectionMethod>;

public constructor() {

@@ -41,27 +32,38 @@ class YoloFaceDetectionService implements FaceDetectionService {
};
}

private async initOnnx() {
console.log("start ort");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const data = new Float32Array(1 * 3 * 640 * 640);
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end");
}

private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}

private preprocessImageBitmapToFloat32ChannelsFirst(

@@ -156,43 +158,6 @@ class YoloFaceDetectionService implements FaceDetectionService {
};
}

/**
* @deprecated The method should not be used
*/
private imageBitmapToTensorData(imageBitmap) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const data = new Float32Array(
1 * 3 * imageBitmap.width * imageBitmap.height,
);
// Populate the Float32Array with normalized pixel values
for (let i = 0; i < pixelData.length; i += 4) {
// Normalize pixel values to the range [0, 1]
data[i / 4] = pixelData[i] / 255.0; // Red channel
data[i / 4 + imageBitmap.width * imageBitmap.height] =
pixelData[i + 1] / 255.0; // Green channel
data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] =
pixelData[i + 2] / 255.0; // Blue channel
}

return {
data: data,
shape: [1, 3, imageBitmap.width, imageBitmap.height],
};
}

// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
private getFacesFromYoloOutput(
rowOutput: Float32Array,

@@ -270,64 +235,10 @@ class YoloFaceDetectionService implements FaceDetectionService {
probability: faceDetection.probability,
};
}

private async estimateOnnx(imageBitmap: ImageBitmap) {
const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["input"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
const runout = await inferenceSession.run(feeds);
const outputData = runout.output.data;
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}

public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
// measure time taken
const facesFromOnnx = await this.estimateOnnx(imageBitmap);
return facesFromOnnx;
}

public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}

export default new YoloFaceDetectionService();
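The comment on `getFacesFromYoloOutput` above notes that the raw output is a Float32Array of shape [25200, 16]. Purely as a hedged illustration of what one such row can contain (assuming the usual YOLOv5-face layout of box centre, size, objectness, five landmark pairs, and a class score; the commit's actual parsing is not shown in this hunk):

```ts
// Assumption-based sketch of decoding one 16-value YOLO face row; not the
// code from this commit.
interface YoloFaceRow {
    box: { x: number; y: number; width: number; height: number };
    landmarks: { x: number; y: number }[];
    score: number;
}

const parseYoloFaceRow = (row: Float32Array): YoloFaceRow => {
    const [cx, cy, w, h, objectness] = row;
    const landmarks: { x: number; y: number }[] = [];
    for (let i = 0; i < 5; i++) {
        landmarks.push({ x: row[5 + 2 * i], y: row[6 + 2 * i] });
    }
    return {
        // Convert centre/size to a top-left box.
        box: { x: cx - w / 2, y: cy - h / 2, width: w, height: h },
        landmarks,
        // Combined confidence: objectness multiplied by the class score.
        score: objectness * row[15],
    };
};
```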
import { euclidean } from "hdbscan";

/**
* Removes duplicate face detections from an array of detections.
*

@@ -45,16 +45,13 @@ export declare type Landmark = Point;

export declare type ImageType = "Original" | "Preview";

export declare type FaceDetectionMethod = "FaceApiSSD" | "YoloFace";
export declare type FaceDetectionMethod = "YoloFace";

export declare type FaceCropMethod = "ArcFace";

export declare type FaceAlignmentMethod =
| "ArcFace"
| "FaceApiDlib"
| "RotatedFaceApiDlib";
export declare type FaceAlignmentMethod = "ArcFace";

export declare type FaceEmbeddingMethod = "MobileFaceNet" | "FaceApiDlib";
export declare type FaceEmbeddingMethod = "MobileFaceNet";

export declare type BlurDetectionMethod = "Laplacian";

@@ -261,13 +258,12 @@ export declare type MLIndex = "files" | "people";

export interface FaceDetectionService {
method: Versioned<FaceDetectionMethod>;
// init(): Promise<void>;

detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
getRelativeDetection(
faceDetection: FaceDetection,
imageDimensions: Dimensions,
): FaceDetection;
dispose(): Promise<void>;
}

export interface FaceCropService {

@@ -288,9 +284,8 @@ export interface FaceAlignmentService {
export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;
// init(): Promise<void>;

getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
dispose(): Promise<void>;
}

export interface BlurDetectionService {

@@ -50,7 +50,7 @@ export class SimpleJob<R extends JobResult> {

try {
const jobResult = await this.runCallback();
if (jobResult.shouldBackoff) {
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(
this.config.maxItervalSec,
this.intervalSec * this.config.backoffMultiplier,

@@ -124,6 +124,10 @@ class MLIDbStorage {
.add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
}
if (oldVersion < 4) {
// TODO(MR): This loses the user's settings.
db.deleteObjectStore("configs");
db.createObjectStore("configs");

db.deleteObjectStore("things");
}

@@ -196,7 +196,7 @@ export interface Electron {
// - ML

/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
*

@@ -207,7 +207,7 @@ export interface Electron {
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;

/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
*

@@ -217,6 +217,22 @@ export interface Electron {
*/
clipTextEmbedding: (text: string) => Promise<Float32Array>;

/**
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;

/**
* Return a MobileFaceNet embedding for the given face data.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
faceEmbedding: (input: Float32Array) => Promise<Float32Array>;

// - File selection
// TODO: Deprecated - use dialogs on the renderer process itself
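A minimal sketch of how the renderer side might drive the two new methods end to end (illustrative only: the wrapper function is hypothetical, and the preprocessing that produces these Float32Arrays lives in the web app's face services shown elsewhere in this diff):

```ts
// `detectFaces` and `faceEmbedding` are the IPC-backed functions from the
// `Electron` interface above (exposed as window.electron in the web app).
// Shapes follow the desktop services: a 640×640 channels-first image for
// detection, and n aligned 112×112×3 faces for embedding.
const runFaceIndexing = async (
    detectFaces: (input: Float32Array) => Promise<Float32Array>,
    faceEmbedding: (input: Float32Array) => Promise<Float32Array>,
    imageData: Float32Array, // [1, 3, 640, 640], values in [0, 1]
    alignedFaceData: Float32Array, // [n, 112, 112, 3]
) => {
    const rawDetections = await detectFaces(imageData);
    const embeddings = await faceEmbedding(alignedFaceData);
    return { rawDetections, embeddings };
};
```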
@@ -35,6 +35,19 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
}
}

// TODO(MR): Temporary method to forward auth tokens to workers
const getAuthToken = () => {
// LS_KEYS.USER
const userJSONString = localStorage.getItem("user");
if (!userJSONString) return undefined;
const json: unknown = JSON.parse(userJSONString);
if (!json || typeof json != "object" || !("token" in json))
return undefined;
const token = json.token;
if (typeof token != "string") return undefined;
return token;
};

/**
* A minimal set of utility functions that we expose to all workers that we
* create.

@@ -44,8 +57,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
*/
const workerBridge = {
logToDisk,
getAuthToken,
convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
ensureElectron().convertToJPEG(inputFileData, filename),
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
faceEmbedding: (input: Float32Array) =>
ensureElectron().faceEmbedding(input),
};

export type WorkerBridge = typeof workerBridge;