diff --git a/desktop/docs/dependencies.md b/desktop/docs/dependencies.md
index 9cced1f818..62f70e8e46 100644
--- a/desktop/docs/dependencies.md
+++ b/desktop/docs/dependencies.md
@@ -1,8 +1,8 @@
 # Dependencies

-* [Electron](#electron)
-* [Dev dependencies](#dev)
-* [Functionality](#functionality)
+- [Electron](#electron)
+- [Dev dependencies](#dev)
+- [Functionality](#functionality)

 ## Electron

@@ -114,8 +114,8 @@ available on the host machine, and is not bundled with our app.
 AI/ML runtime. It powers both natural language searches (using CLIP) and face
 detection (using YOLO).

-[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
-JPEG data into raw RGB bytes before passing it to ONNX.
+[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
+data into raw RGB bytes before passing it to ONNX.

 html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.

diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts
index 2b328bb986..180e68cdcf 100644
--- a/desktop/src/main/ipc.ts
+++ b/desktop/src/main/ipc.ts
@@ -36,13 +36,14 @@ import {
     updateAndRestart,
     updateOnNextRestart,
 } from "./services/app-update";
-import { clipImageEmbedding, clipTextEmbedding } from "./services/clip";
 import { runFFmpegCmd } from "./services/ffmpeg";
 import { getDirFiles } from "./services/fs";
 import {
     convertToJPEG,
     generateImageThumbnail,
 } from "./services/imageProcessor";
+import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
+import { detectFaces, faceEmbedding } from "./services/ml-face";
 import {
     clearStores,
     encryptionKey,
@@ -146,6 +147,14 @@ export const attachIPCHandlers = () => {
         clipTextEmbedding(text),
     );

+    ipcMain.handle("detectFaces", (_, input: Float32Array) =>
+        detectFaces(input),
+    );
+
+    ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>
+        faceEmbedding(input),
+    );
+
     // - File selection

     ipcMain.handle("selectDirectory", () => selectDirectory());
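The two handlers registered above are one half of an IPC channel; the other
half is an `ipcRenderer.invoke` with the same channel name in the preload
script, which appears later in this diff. A minimal sketch of the pairing:

    // main process (desktop/src/main/ipc.ts)
    ipcMain.handle("detectFaces", (_, input: Float32Array) =>
        detectFaces(input),
    );

    // preload (desktop/src/preload.ts)
    const detectFaces = (input: Float32Array): Promise<Float32Array> =>
        ipcRenderer.invoke("detectFaces", input);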
diff --git a/desktop/src/main/services/clip.ts b/desktop/src/main/services/ml-clip.ts
similarity index 74%
rename from desktop/src/main/services/clip.ts
rename to desktop/src/main/services/ml-clip.ts
index 525e613424..63fa751482 100644
--- a/desktop/src/main/services/clip.ts
+++ b/desktop/src/main/services/ml-clip.ts
@@ -1,19 +1,13 @@
 /**
- * @file Compute CLIP embeddings
+ * @file Compute CLIP embeddings for images and text.
  *
- * @see `web/apps/photos/src/services/clip-service.ts` for more details. This
- * file implements the Node.js implementation of the actual embedding
- * computation. By doing it in the Node.js layer, we can use the binary ONNX
- * runtimes which are 10-20x faster than the WASM based web ones.
+ * The embeddings are computed using the ONNX runtime, with CLIP as the model.
  *
- * The embeddings are computed using ONNX runtime. The model itself is not
- * shipped with the app but is downloaded on demand.
+ * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
-import { app, net } from "electron/main";
 import { existsSync } from "fs";
 import jpeg from "jpeg-js";
 import fs from "node:fs/promises";
-import path from "node:path";
 import * as ort from "onnxruntime-node";
 import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
 import { CustomErrors } from "../../types/ipc";
@@ -21,6 +15,12 @@ import { writeStream } from "../fs";
 import log from "../log";
 import { generateTempFilePath } from "../temp";
 import { deleteTempFile } from "./ffmpeg";
+import {
+    createInferenceSession,
+    downloadModel,
+    modelPathDownloadingIfNeeded,
+    modelSavePath,
+} from "./ml";

 const textModelName = "clip-text-vit-32-uint8.onnx";
 const textModelByteSize = 64173509; // 61.2 MB
@@ -28,55 +28,20 @@ const textModelByteSize = 64173509; // 61.2 MB
 const imageModelName = "clip-image-vit-32-float32.onnx";
 const imageModelByteSize = 351468764; // 335.2 MB

-/** Return the path where the given {@link modelName} is meant to be saved */
-const modelSavePath = (modelName: string) =>
-    path.join(app.getPath("userData"), "models", modelName);
-
-const downloadModel = async (saveLocation: string, name: string) => {
-    // `mkdir -p` the directory where we want to save the model.
-    const saveDir = path.dirname(saveLocation);
-    await fs.mkdir(saveDir, { recursive: true });
-    // Download
-    log.info(`Downloading CLIP model from ${name}`);
-    const url = `https://models.ente.io/${name}`;
-    const res = await net.fetch(url);
-    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
-    // Save
-    await writeStream(saveLocation, res.body);
-    log.info(`Downloaded CLIP model ${name}`);
-};
-
-let activeImageModelDownload: Promise<void> | undefined;
+let activeImageModelDownload: Promise<string> | undefined;

 const imageModelPathDownloadingIfNeeded = async () => {
     try {
-        const modelPath = modelSavePath(imageModelName);
         if (activeImageModelDownload) {
             log.info("Waiting for CLIP image model download to finish");
             await activeImageModelDownload;
         } else {
-            if (!existsSync(modelPath)) {
-                log.info("CLIP image model not found, downloading");
-                activeImageModelDownload = downloadModel(
-                    modelPath,
-                    imageModelName,
-                );
-                await activeImageModelDownload;
-            } else {
-                const localFileSize = (await fs.stat(modelPath)).size;
-                if (localFileSize !== imageModelByteSize) {
-                    log.error(
-                        `CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
-                    );
-                    activeImageModelDownload = downloadModel(
-                        modelPath,
-                        imageModelName,
-                    );
-                    await activeImageModelDownload;
-                }
-            }
+            activeImageModelDownload = modelPathDownloadingIfNeeded(
+                imageModelName,
+                imageModelByteSize,
+            );
+            return await activeImageModelDownload;
         }
-        return modelPath;
     } finally {
         activeImageModelDownload = undefined;
     }
 };
@@ -84,6 +49,8 @@ const imageModelPathDownloadingIfNeeded = async () => {

 let textModelDownloadInProgress = false;

+/* TODO(MR): use the generic method. Then we can remove the exports for the
+   internal helper functions that we use here */
 const textModelPathDownloadingIfNeeded = async () => {
     if (textModelDownloadInProgress)
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
@@ -123,13 +90,6 @@ const textModelPathDownloadingIfNeeded = async () => {
     return modelPath;
 };

-const createInferenceSession = async (modelPath: string) => {
-    return await ort.InferenceSession.create(modelPath, {
-        intraOpNumThreads: 1,
-        enableCpuMemArena: false,
-    });
-};
-
 let imageSessionPromise: Promise<any> | undefined;

 const onnxImageSession = async () => {
@@ -174,7 +134,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
     const results = await imageSession.run(feeds);
     log.debug(
         () =>
-            `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+            `onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
     const imageEmbedding = results["output"].data; // Float32Array
     return normalizeEmbedding(imageEmbedding);
@@ -281,7 +241,7 @@ export const clipTextEmbedding = async (text: string) => {
     const results = await imageSession.run(feeds);
     log.debug(
         () =>
-            `CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+            `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
     const textEmbedding = results["output"].data;
     return normalizeEmbedding(textEmbedding);
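Both `imageModelPathDownloadingIfNeeded` above and its two face counterparts
later in this diff use the same module-level promise as a single-flight guard:
concurrent callers join the download already in flight instead of starting
another one. One subtlety, as written: a caller that arrives while a download
is active awaits it but then falls out of the `if` without a value. An
illustrative variant (not part of this diff) that also hands the path to late
joiners:

    // Illustrative single-flight guard, reusing the helpers from ml.ts.
    let activeDownload: Promise<string> | undefined;

    const imageModelPathSingleFlight = async () => {
        try {
            if (!activeDownload)
                activeDownload = modelPathDownloadingIfNeeded(
                    imageModelName,
                    imageModelByteSize,
                );
            // Everyone, first caller or late joiner, gets the resolved path.
            return await activeDownload;
        } finally {
            // Reset so that a failed download can be retried next call.
            activeDownload = undefined;
        }
    };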
diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts
new file mode 100644
index 0000000000..1f007c5fd8
--- /dev/null
+++ b/desktop/src/main/services/ml-face.ts
@@ -0,0 +1,108 @@
+/**
+ * @file Various face recognition related tasks.
+ *
+ * - Face detection with the YOLO model.
+ * - Face embedding with the MobileFaceNet model.
+ *
+ * The runtime used is ONNX.
+ */
+import * as ort from "onnxruntime-node";
+import log from "../log";
+import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
+
+const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
+const faceDetectionModelByteSize = 30762872; // 29.3 MB
+
+const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
+const faceEmbeddingModelByteSize = 5286998; // 5 MB
+
+let activeFaceDetectionModelDownload: Promise<string> | undefined;
+
+const faceDetectionModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeFaceDetectionModelDownload) {
+            log.info("Waiting for face detection model download to finish");
+            await activeFaceDetectionModelDownload;
+        } else {
+            activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
+                faceDetectionModelName,
+                faceDetectionModelByteSize,
+            );
+            return await activeFaceDetectionModelDownload;
+        }
+    } finally {
+        activeFaceDetectionModelDownload = undefined;
+    }
+};
+
+let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
+
+const faceDetectionSession = async () => {
+    if (!_faceDetectionSession) {
+        _faceDetectionSession =
+            faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
+                createInferenceSession(modelPath),
+            );
+    }
+    return _faceDetectionSession;
+};
+
+let activeFaceEmbeddingModelDownload: Promise<string> | undefined;
+
+const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeFaceEmbeddingModelDownload) {
+            log.info("Waiting for face embedding model download to finish");
+            await activeFaceEmbeddingModelDownload;
+        } else {
+            activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
+                faceEmbeddingModelName,
+                faceEmbeddingModelByteSize,
+            );
+            return await activeFaceEmbeddingModelDownload;
+        }
+    } finally {
+        activeFaceEmbeddingModelDownload = undefined;
+    }
+};
+
+let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;
+
+const faceEmbeddingSession = async () => {
+    if (!_faceEmbeddingSession) {
+        _faceEmbeddingSession =
+            faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
+                createInferenceSession(modelPath),
+            );
+    }
+    return _faceEmbeddingSession;
+};
+
+export const detectFaces = async (input: Float32Array) => {
+    const session = await faceDetectionSession();
+    const t = Date.now();
+    const feeds = {
+        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
+    };
+    const results = await session.run(feeds);
+    log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`);
+    return results["output"].data;
+};
+
+export const faceEmbedding = async (input: Float32Array) => {
+    // Dimension (width and height) of each aligned face
+    const mobileFaceNetFaceSize = 112;
+    // Shorter alias
+    const z = mobileFaceNetFaceSize;
+    // Size of each face's data in the batch
+    const n = Math.round(input.length / (z * z * 3));
+    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
+
+    const session = await faceEmbeddingSession();
+    const t = Date.now();
+    const feeds = { img_inputs: inputTensor };
+    const results = await session.run(feeds);
+    log.debug(() => `onnx/mobilefacenet face embedding took ${Date.now() - t} ms`);
+    // TODO: What's with this type? It works in practice, but double check.
+    return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
+};
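In terms of data layout, `detectFaces` expects one 640x640 RGB image in
channels-first order, while `faceEmbedding` expects n aligned faces packed
contiguously in channels-last order. A hypothetical caller, using the preload
`electron` API whose type declarations are extended later in this diff:

    // Hypothetical usage; real inputs come from the preprocessing code in
    // the web layer, the values here are placeholders.
    const image = new Float32Array(1 * 3 * 640 * 640); // 1 x 3 x 640 x 640
    const yoloOutput = await electron.detectFaces(image);

    const faceCount = 2;
    const faces = new Float32Array(faceCount * 112 * 112 * 3); // n x 112 x 112 x 3
    const embeddings = await electron.faceEmbedding(faces);
    console.log(embeddings.length / 192); // => 2, one 192-float embedding per face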
diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts
new file mode 100644
index 0000000000..10402db217
--- /dev/null
+++ b/desktop/src/main/services/ml.ts
@@ -0,0 +1,79 @@
+/**
+ * @file AI/ML related functionality.
+ *
+ * @see also `ml-clip.ts`, `ml-face.ts`.
+ *
+ * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
+ * for various tasks are not shipped with the app but are downloaded on demand.
+ *
+ * The primary reason for doing these tasks in the Node.js layer is so that we
+ * can use the binary ONNX runtime, which is 10-20x faster than the WASM based
+ * web one.
+ */
+import { app, net } from "electron/main";
+import { existsSync } from "fs";
+import fs from "node:fs/promises";
+import path from "node:path";
+import * as ort from "onnxruntime-node";
+import { writeStream } from "../fs";
+import log from "../log";
+
+/**
+ * Download the model named {@link modelName} if we don't already have it.
+ *
+ * Also verify that the size of the model we get matches {@link
+ * expectedByteSize} (if not, redownload it).
+ *
+ * @returns the path to the model on the local machine.
+ */
+export const modelPathDownloadingIfNeeded = async (
+    modelName: string,
+    expectedByteSize: number,
+) => {
+    const modelPath = modelSavePath(modelName);
+
+    if (!existsSync(modelPath)) {
+        log.info(`Model ${modelName} not found, downloading`);
+        await downloadModel(modelPath, modelName);
+    } else {
+        const size = (await fs.stat(modelPath)).size;
+        if (size !== expectedByteSize) {
+            log.error(
+                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
+            );
+            await downloadModel(modelPath, modelName);
+        }
+    }
+
+    return modelPath;
+};
+
+/** Return the path where the given {@link modelName} is meant to be saved */
+export const modelSavePath = (modelName: string) =>
+    path.join(app.getPath("userData"), "models", modelName);
+
+export const downloadModel = async (saveLocation: string, name: string) => {
+    // `mkdir -p` the directory where we want to save the model.
+    const saveDir = path.dirname(saveLocation);
+    await fs.mkdir(saveDir, { recursive: true });
+    // Download
+    log.info(`Downloading ML model ${name}`);
+    const url = `https://models.ente.io/${name}`;
+    const res = await net.fetch(url);
+    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
+    // Save
+    await writeStream(saveLocation, res.body);
+    log.info(`Downloaded ML model ${name}`);
+};
+
+/**
+ * Create an ONNX {@link InferenceSession} with some defaults.
+ */
+export const createInferenceSession = async (modelPath: string) => {
+    return await ort.InferenceSession.create(modelPath, {
+        // Restrict the number of threads to 1
+        intraOpNumThreads: 1,
+        // Be more conservative with RAM usage
+        enableCpuMemArena: false,
+    });
+};
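Together, `modelPathDownloadingIfNeeded`, `modelSavePath`, and
`createInferenceSession` are the seams that `ml-clip.ts` and `ml-face.ts` plug
into. A sketch of what wiring up a further model might look like (the model
name and byte size below are invented for illustration):

    import * as ort from "onnxruntime-node";
    import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";

    const someModelName = "some-model.onnx"; // hypothetical
    const someModelByteSize = 12345678; // hypothetical

    let _session: Promise<ort.InferenceSession> | undefined;

    // Lazily download the model and create (once) an inference session.
    const someModelSession = async () => {
        if (!_session) {
            _session = modelPathDownloadingIfNeeded(
                someModelName,
                someModelByteSize,
            ).then((modelPath) => createInferenceSession(modelPath));
        }
        return _session;
    };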
diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts
index 07736502bd..2db39e2290 100644
--- a/desktop/src/preload.ts
+++ b/desktop/src/preload.ts
@@ -143,6 +143,12 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
 const clipTextEmbedding = (text: string): Promise<Float32Array> =>
     ipcRenderer.invoke("clipTextEmbedding", text);

+const detectFaces = (input: Float32Array): Promise<Float32Array> =>
+    ipcRenderer.invoke("detectFaces", input);
+
+const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
+    ipcRenderer.invoke("faceEmbedding", input);
+
 // - File selection

 // TODO: Deprecated - use dialogs on the renderer process itself
@@ -322,6 +328,8 @@ contextBridge.exposeInMainWorld("electron", {
     // - ML
     clipImageEmbedding,
     clipTextEmbedding,
+    detectFaces,
+    faceEmbedding,

     // - File selection
     selectDirectory,
diff --git a/web/apps/photos/src/constants/mlConfig.ts b/web/apps/photos/src/constants/mlConfig.ts
index ff3eed264a..929594e1c1 100644
--- a/web/apps/photos/src/constants/mlConfig.ts
+++ b/web/apps/photos/src/constants/mlConfig.ts
@@ -53,15 +53,4 @@ export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
     enabled: false,
 };

-export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
-
-export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
-
 export const MAX_ML_SYNC_ERROR_COUNT = 1;
-
-export const TEXT_DETECTION_TIMEOUT_MS = [10000, 30000, 60000, 120000, 240000];
-
-export const MOBILEFACENET_FACE_SIZE = 112;
-export const MOBILEFACENET_EMBEDDING_SIZE = 192;
-
-export const BATCHES_BEFORE_SYNCING_INDEX = 5;
diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts
index b93b01532b..a4309e314c 100644
--- a/web/apps/photos/src/services/embeddingService.ts
+++ b/web/apps/photos/src/services/embeddingService.ts
@@ -1,4 +1,6 @@
+import { inWorker } from "@/next/env";
 import log from "@/next/log";
+import { workerBridge } from "@/next/worker/worker-bridge";
 import ComlinkCryptoWorker from "@ente/shared/crypto";
 import { CustomError } from "@ente/shared/error";
 import HTTPService from "@ente/shared/network/HTTPService";
@@ -262,7 +264,9 @@ export const putEmbedding = async (
     putEmbeddingReq: PutEmbeddingRequest,
 ): Promise<EncryptedEmbedding> => {
     try {
-        const token = getToken();
+        const token = inWorker()
+            ? await workerBridge.getAuthToken()
+            : getToken();
         if (!token) {
             log.info("putEmbedding failed: token not found");
             throw Error(CustomError.TOKEN_MISSING);
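The `inWorker()` branch exists because this function now also runs inside web
workers, where `localStorage` (and thus `getToken()`) is unavailable; there
the token is fetched from the main thread over the comlink bridge instead
(see `comlink-worker.ts` at the end of this diff). The check itself is not
shown in this diff; a plausible implementation is something along these lines:

    // Assumption, not part of this diff: dedicated web workers expose
    // importScripts, regular window contexts do not.
    export const inWorker = () => typeof importScripts === "function";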
diff --git a/web/apps/photos/src/services/machineLearning/faceService.ts b/web/apps/photos/src/services/machineLearning/faceService.ts
index 449ae0b964..3116ac23c2 100644
--- a/web/apps/photos/src/services/machineLearning/faceService.ts
+++ b/web/apps/photos/src/services/machineLearning/faceService.ts
@@ -55,7 +55,7 @@ class FaceService {
             await syncContext.faceDetectionService.detectFaces(imageBitmap);
         console.timeEnd(timerId);
         console.log("faceDetections: ", faceDetections?.length);
-        // log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
+
         // TODO: reenable faces filtering based on width
         const detectedFaces = faceDetections?.map((detection) => {
             return {
@@ -150,7 +150,7 @@ class FaceService {
         imageBitmap.close();
         log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
-        // log.info('4 TF Memory stats: ',JSON.stringify(tf.memory()));
+
         return faceImages;
     }
@@ -187,7 +187,6 @@ class FaceService {
         newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
         log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
-        // log.info('5 TF Memory stats: ',JSON.stringify(tf.memory()));
     }

     async syncFileFaceMakeRelativeDetections(
@@ -226,11 +225,21 @@ class FaceService {
             face.detection,
             syncContext.config.faceCrop,
         );
-        face.crop = await storeFaceCrop(
-            face.id,
-            faceCrop,
-            syncContext.config.faceCrop.blobOptions,
-        );
+        try {
+            face.crop = await storeFaceCrop(
+                face.id,
+                faceCrop,
+                syncContext.config.faceCrop.blobOptions,
+            );
+        } catch (e) {
+            // TODO(MR): Temporarily ignoring errors about failing cache puts
+            // when using a custom scheme in Electron. Needs an alternative
+            // approach, perhaps OPFS.
+            console.error(
+                "Ignoring error when caching face crop, the face crop will not be available",
+                e,
+            );
+        }
         const blob = await imageBitmapToBlob(faceCrop.image);
         faceCrop.image.close();
         return blob;
diff --git a/web/apps/photos/src/services/machineLearning/laplacianBlurDetectionService.ts b/web/apps/photos/src/services/machineLearning/laplacianBlurDetectionService.ts
index b5842f70c2..14178a5351 100644
--- a/web/apps/photos/src/services/machineLearning/laplacianBlurDetectionService.ts
+++ b/web/apps/photos/src/services/machineLearning/laplacianBlurDetectionService.ts
@@ -1,10 +1,10 @@
-import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
 import {
     BlurDetectionMethod,
     BlurDetectionService,
     Versioned,
 } from "types/machineLearning";
 import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
+import { mobileFaceNetFaceSize } from "./mobileFaceNetEmbeddingService";

 class LaplacianBlurDetectionService implements BlurDetectionService {
     public method: Versioned<BlurDetectionMethod>;
@@ -19,7 +19,7 @@ class LaplacianBlurDetectionService implements BlurDetectionService {
     public detectBlur(alignedFaces: Float32Array): number[] {
         const numFaces = Math.round(
             alignedFaces.length /
-                (MOBILEFACENET_FACE_SIZE * MOBILEFACENET_FACE_SIZE * 3),
+                (mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3),
         );
         const blurValues: number[] = [];
         for (let i = 0; i < numFaces; i++) {
diff --git a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts b/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts
index 36e37d9b83..991ae68087 100644
--- a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts
+++ b/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts
@@ -203,9 +203,6 @@ export class LocalMLSyncContext implements MLSyncContext {
     }

     public async dispose() {
-        // await this.faceDetectionService.dispose();
-        // await this.faceEmbeddingService.dispose();
-
         this.localFilesMap = undefined;
         await this.syncQueue.onIdle();
         this.syncQueue.removeAllListeners();
diff --git a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
index d62d6f829e..c5df14b224 100644
--- a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
+++ b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
@@ -186,8 +186,7 @@ class MLWorkManager {
             return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
         });

-        // @ts-expect-error "TODO: Fix ML related type errors"
-        if ("message" in result) {
+        if (result instanceof Error) {
             // TODO: redirect/refresh to gallery in case of session_expired
             // may not be required as uploader should anyways take care of this
             console.error("Error while syncing local file: ", result);
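Note that `detectBlur` recovers the face count with the same arithmetic that
`ml-face.ts` uses to build its input tensor: n aligned faces are packed into a
single Float32Array of n x 112 x 112 x 3 values, so dividing by the per-face
size gives back n. For example:

    const faceSize = 112; // mobileFaceNetFaceSize
    const alignedFaces = new Float32Array(2 * faceSize * faceSize * 3);
    const numFaces = Math.round(alignedFaces.length / (faceSize * faceSize * 3));
    console.log(numFaces); // => 2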
diff --git a/web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts b/web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts
index 39953689e6..818b8a5d12 100644
--- a/web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts
+++ b/web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts
@@ -1,7 +1,4 @@
-import {
-    MOBILEFACENET_EMBEDDING_SIZE,
-    MOBILEFACENET_FACE_SIZE,
-} from "constants/mlConfig";
+import { workerBridge } from "@/next/worker/worker-bridge";
 import {
     FaceEmbedding,
     FaceEmbeddingMethod,
@@ -9,17 +6,9 @@ import {
     Versioned,
 } from "types/machineLearning";

-// TODO(MR): onnx-yolo
-// import * as ort from "onnxruntime-web";
-// import { env } from "onnxruntime-web";
-const ort: any = {};
+export const mobileFaceNetFaceSize = 112;

-// TODO(MR): onnx-yolo
-// env.wasm.wasmPaths = "/js/onnx/";
 class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
-    // TODO(MR): onnx-yolo
-    // private onnxInferenceSession?: ort.InferenceSession;
-    private onnxInferenceSession?: any;
     public method: Versioned<FaceEmbeddingMethod>;
     public faceSize: number;

@@ -28,80 +17,25 @@ class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
         value: "MobileFaceNet",
         version: 2,
     };
-        this.faceSize = MOBILEFACENET_FACE_SIZE;
-        // TODO: set timeout
-    }
-
-    private async initOnnx() {
-        console.log("start ort mobilefacenet");
-        this.onnxInferenceSession = await ort.InferenceSession.create(
-            "/models/mobilefacenet/mobilefacenet_opset15.onnx",
-        );
-        const faceBatchSize = 1;
-        const data = new Float32Array(
-            faceBatchSize * 3 * this.faceSize * this.faceSize,
-        );
-        const inputTensor = new ort.Tensor("float32", data, [
-            faceBatchSize,
-            this.faceSize,
-            this.faceSize,
-            3,
-        ]);
-        // TODO(MR): onnx-yolo
-        // const feeds: Record<string, ort.Tensor> = {};
-        const feeds: Record<string, any> = {};
-        const name = this.onnxInferenceSession.inputNames[0];
-        feeds[name] = inputTensor;
-        await this.onnxInferenceSession.run(feeds);
-        console.log("start end mobilefacenet");
-    }
-
-    private async getOnnxInferenceSession() {
-        if (!this.onnxInferenceSession) {
-            await this.initOnnx();
-        }
-        return this.onnxInferenceSession;
+        this.faceSize = mobileFaceNetFaceSize;
     }

     public async getFaceEmbeddings(
         faceData: Float32Array,
     ): Promise<Array<FaceEmbedding>> {
-        const inputTensor = new ort.Tensor("float32", faceData, [
-            Math.round(faceData.length / (this.faceSize * this.faceSize * 3)),
-            this.faceSize,
-            this.faceSize,
-            3,
-        ]);
-        // TODO(MR): onnx-yolo
-        // const feeds: Record<string, ort.Tensor> = {};
-        const feeds: Record<string, any> = {};
-        feeds["img_inputs"] = inputTensor;
-        const inferenceSession = await this.getOnnxInferenceSession();
-        // TODO(MR): onnx-yolo
-        // const runout: ort.InferenceSession.OnnxValueMapType =
-        const runout: any = await inferenceSession.run(feeds);
-        // const test = runout.embeddings;
-        // const test2 = test.cpuData;
-        const outputData = runout.embeddings["cpuData"] as Float32Array;
+        const outputData = await workerBridge.faceEmbedding(faceData);
+
+        const embeddingSize = 192;
         const embeddings = new Array<FaceEmbedding>(
-            outputData.length / MOBILEFACENET_EMBEDDING_SIZE,
+            outputData.length / embeddingSize,
         );
         for (let i = 0; i < embeddings.length; i++) {
             embeddings[i] = new Float32Array(
-                outputData.slice(
-                    i * MOBILEFACENET_EMBEDDING_SIZE,
-                    (i + 1) * MOBILEFACENET_EMBEDDING_SIZE,
-                ),
+                outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
             );
         }
         return embeddings;
     }
-
-    public async dispose() {
-        const inferenceSession = await this.getOnnxInferenceSession();
-        inferenceSession?.release();
-        this.onnxInferenceSession = undefined;
-    }
 }

 export default new MobileFaceNetEmbeddingService();
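With the local ONNX session gone, `getFaceEmbeddings` is now just packing and
unpacking around a remote call. The chain, as this diff wires it: web worker
-> `workerBridge.faceEmbedding` (comlink proxy) -> preload `faceEmbedding`
(`ipcRenderer.invoke`) -> `faceEmbedding` in `ml-face.ts` (Node.js,
onnxruntime-node). The unpacking half can be exercised standalone:

    const embeddingSize = 192; // MobileFaceNet embedding dimension
    const outputData = new Float32Array(3 * embeddingSize); // say, three faces
    const embeddings: Float32Array[] = [];
    // Cut the flat native output into one 192-float embedding per face.
    for (let i = 0; i < outputData.length / embeddingSize; i++)
        embeddings.push(
            outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
        );
    console.log(embeddings.length); // => 3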
"../../../thirdparty/face-api/classes"; -// TODO(MR): onnx-yolo -// import * as ort from "onnxruntime-web"; -// import { env } from "onnxruntime-web"; -const ort: any = {}; - -// TODO(MR): onnx-yolo -// env.wasm.wasmPaths = "/js/onnx/"; class YoloFaceDetectionService implements FaceDetectionService { - // TODO(MR): onnx-yolo - // private onnxInferenceSession?: ort.InferenceSession; - private onnxInferenceSession?: any; public method: Versioned; public constructor() { @@ -41,27 +32,38 @@ class YoloFaceDetectionService implements FaceDetectionService { }; } - private async initOnnx() { - console.log("start ort"); - this.onnxInferenceSession = await ort.InferenceSession.create( - "/models/yoloface/yolov5s_face_640_640_dynamic.onnx", + public async detectFaces( + imageBitmap: ImageBitmap, + ): Promise> { + const maxFaceDistancePercent = Math.sqrt(2) / 100; + const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent; + const preprocessResult = + this.preprocessImageBitmapToFloat32ChannelsFirst( + imageBitmap, + 640, + 640, + ); + const data = preprocessResult.data; + const resized = preprocessResult.newSize; + const outputData = await workerBridge.detectFaces(data); + const faces = this.getFacesFromYoloOutput( + outputData as Float32Array, + 0.7, ); - const data = new Float32Array(1 * 3 * 640 * 640); - const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]); - // TODO(MR): onnx-yolo - // const feeds: Record = {}; - const feeds: Record = {}; - const name = this.onnxInferenceSession.inputNames[0]; - feeds[name] = inputTensor; - await this.onnxInferenceSession.run(feeds); - console.log("start end"); - } - - private async getOnnxInferenceSession() { - if (!this.onnxInferenceSession) { - await this.initOnnx(); - } - return this.onnxInferenceSession; + const inBox = newBox(0, 0, resized.width, resized.height); + const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height); + const transform = computeTransformToBox(inBox, toBox); + const faceDetections: Array = faces?.map((f) => { + const box = transformBox(f.box, transform); + const normLandmarks = f.landmarks; + const landmarks = transformPoints(normLandmarks, transform); + return { + box, + landmarks, + probability: f.probability as number, + } as FaceDetection; + }); + return removeDuplicateDetections(faceDetections, maxFaceDistance); } private preprocessImageBitmapToFloat32ChannelsFirst( @@ -156,43 +158,6 @@ class YoloFaceDetectionService implements FaceDetectionService { }; } - /** - * @deprecated The method should not be used - */ - private imageBitmapToTensorData(imageBitmap) { - // Create an OffscreenCanvas and set its size - const offscreenCanvas = new OffscreenCanvas( - imageBitmap.width, - imageBitmap.height, - ); - const ctx = offscreenCanvas.getContext("2d"); - ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height); - const imageData = ctx.getImageData( - 0, - 0, - imageBitmap.width, - imageBitmap.height, - ); - const pixelData = imageData.data; - const data = new Float32Array( - 1 * 3 * imageBitmap.width * imageBitmap.height, - ); - // Populate the Float32Array with normalized pixel values - for (let i = 0; i < pixelData.length; i += 4) { - // Normalize pixel values to the range [0, 1] - data[i / 4] = pixelData[i] / 255.0; // Red channel - data[i / 4 + imageBitmap.width * imageBitmap.height] = - pixelData[i + 1] / 255.0; // Green channel - data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] = - pixelData[i + 2] / 255.0; // Blue channel - } - - return { - data: data, - shape: [1, 3, 
@@ -156,43 +158,6 @@
         };
     }

-    /**
-     * @deprecated The method should not be used
-     */
-    private imageBitmapToTensorData(imageBitmap) {
-        // Create an OffscreenCanvas and set its size
-        const offscreenCanvas = new OffscreenCanvas(
-            imageBitmap.width,
-            imageBitmap.height,
-        );
-        const ctx = offscreenCanvas.getContext("2d");
-        ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
-        const imageData = ctx.getImageData(
-            0,
-            0,
-            imageBitmap.width,
-            imageBitmap.height,
-        );
-        const pixelData = imageData.data;
-        const data = new Float32Array(
-            1 * 3 * imageBitmap.width * imageBitmap.height,
-        );
-        // Populate the Float32Array with normalized pixel values
-        for (let i = 0; i < pixelData.length; i += 4) {
-            // Normalize pixel values to the range [0, 1]
-            data[i / 4] = pixelData[i] / 255.0; // Red channel
-            data[i / 4 + imageBitmap.width * imageBitmap.height] =
-                pixelData[i + 1] / 255.0; // Green channel
-            data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] =
-                pixelData[i + 2] / 255.0; // Blue channel
-        }
-
-        return {
-            data: data,
-            shape: [1, 3, imageBitmap.width, imageBitmap.height],
-        };
-    }
-
     // The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
     private getFacesFromYoloOutput(
         rowOutput: Float32Array,
@@ -270,64 +235,10 @@ class YoloFaceDetectionService implements FaceDetectionService {
             probability: faceDetection.probability,
         };
     }
-
-    private async estimateOnnx(imageBitmap: ImageBitmap) {
-        const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
-        const preprocessResult =
-            this.preprocessImageBitmapToFloat32ChannelsFirst(
-                imageBitmap,
-                640,
-                640,
-            );
-        const data = preprocessResult.data;
-        const resized = preprocessResult.newSize;
-        const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
-        // TODO(MR): onnx-yolo
-        // const feeds: Record<string, ort.Tensor> = {};
-        const feeds: Record<string, any> = {};
-        feeds["input"] = inputTensor;
-        const inferenceSession = await this.getOnnxInferenceSession();
-        const runout = await inferenceSession.run(feeds);
-        const outputData = runout.output.data;
-        const faces = this.getFacesFromYoloOutput(
-            outputData as Float32Array,
-            0.7,
-        );
-        const inBox = newBox(0, 0, resized.width, resized.height);
-        const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
-        const transform = computeTransformToBox(inBox, toBox);
-        const faceDetections: Array<FaceDetection> = faces?.map((f) => {
-            const box = transformBox(f.box, transform);
-            const normLandmarks = f.landmarks;
-            const landmarks = transformPoints(normLandmarks, transform);
-            return {
-                box,
-                landmarks,
-                probability: f.probability as number,
-            } as FaceDetection;
-        });
-        return removeDuplicateDetections(faceDetections, maxFaceDistance);
-    }
-
-    public async detectFaces(
-        imageBitmap: ImageBitmap,
-    ): Promise<Array<FaceDetection>> {
-        // measure time taken
-        const facesFromOnnx = await this.estimateOnnx(imageBitmap);
-        return facesFromOnnx;
-    }
-
-    public async dispose() {
-        const inferenceSession = await this.getOnnxInferenceSession();
-        inferenceSession?.release();
-        this.onnxInferenceSession = undefined;
-    }
 }

 export default new YoloFaceDetectionService();

-import { euclidean } from "hdbscan";
-
 /**
  * Removes duplicate face detections from an array of detections.
  *
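`removeDuplicateDetections` (its body lies outside the context shown by this
diff) now relies on the `euclidean` helper that the hdbscan package exports,
imported at the top of the file instead of after the default export. Roughly,
the criterion is that two detections denote the same face when their positions
are within `maxFaceDistance` of each other; a sketch of that pairwise test
(the actual selection logic is not shown in this diff):

    import { euclidean } from "hdbscan";

    // Rough sketch: are two detections close enough to be duplicates?
    const isDuplicate = (
        a: { x: number; y: number },
        b: { x: number; y: number },
        maxFaceDistance: number,
    ) => euclidean([a.x, a.y], [b.x, b.y]) <= maxFaceDistance;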
diff --git a/web/apps/photos/src/types/machineLearning/index.ts b/web/apps/photos/src/types/machineLearning/index.ts
index 3def20a088..d0c902333c 100644
--- a/web/apps/photos/src/types/machineLearning/index.ts
+++ b/web/apps/photos/src/types/machineLearning/index.ts
@@ -45,16 +45,13 @@ export declare type Landmark = Point;

 export declare type ImageType = "Original" | "Preview";

-export declare type FaceDetectionMethod = "FaceApiSSD" | "YoloFace";
+export declare type FaceDetectionMethod = "YoloFace";

 export declare type FaceCropMethod = "ArcFace";

-export declare type FaceAlignmentMethod =
-    | "ArcFace"
-    | "FaceApiDlib"
-    | "RotatedFaceApiDlib";
+export declare type FaceAlignmentMethod = "ArcFace";

-export declare type FaceEmbeddingMethod = "MobileFaceNet" | "FaceApiDlib";
+export declare type FaceEmbeddingMethod = "MobileFaceNet";

 export declare type BlurDetectionMethod = "Laplacian";

@@ -261,13 +258,12 @@ export declare type MLIndex = "files" | "people";

 export interface FaceDetectionService {
     method: Versioned<FaceDetectionMethod>;
-    // init(): Promise<void>;
+
     detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
     getRelativeDetection(
         faceDetection: FaceDetection,
         imageDimensions: Dimensions,
     ): FaceDetection;
-    dispose(): Promise<void>;
 }

 export interface FaceCropService {
@@ -288,9 +284,8 @@ export interface FaceAlignmentService {
 export interface FaceEmbeddingService {
     method: Versioned<FaceEmbeddingMethod>;
     faceSize: number;
-    // init(): Promise<void>;
+
     getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
-    dispose(): Promise<void>;
 }

 export interface BlurDetectionService {
diff --git a/web/apps/photos/src/utils/common/job.ts b/web/apps/photos/src/utils/common/job.ts
index f549966b66..365f879e95 100644
--- a/web/apps/photos/src/utils/common/job.ts
+++ b/web/apps/photos/src/utils/common/job.ts
@@ -50,7 +50,7 @@ export class SimpleJob {
         try {
             const jobResult = await this.runCallback();
-            if (jobResult.shouldBackoff) {
+            if (jobResult && jobResult.shouldBackoff) {
                 this.intervalSec = Math.min(
                     this.config.maxItervalSec,
                     this.intervalSec * this.config.backoffMultiplier,
diff --git a/web/apps/photos/src/utils/storage/mlIDbStorage.ts b/web/apps/photos/src/utils/storage/mlIDbStorage.ts
index d7e24cbe80..8be60afacf 100644
--- a/web/apps/photos/src/utils/storage/mlIDbStorage.ts
+++ b/web/apps/photos/src/utils/storage/mlIDbStorage.ts
@@ -124,6 +124,10 @@ class MLIDbStorage {
                 .add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
         }
         if (oldVersion < 4) {
+            // TODO(MR): This loses the user's settings.
+            db.deleteObjectStore("configs");
+            db.createObjectStore("configs");
+
             db.deleteObjectStore("things");
         }
diff --git a/web/packages/next/types/ipc.ts b/web/packages/next/types/ipc.ts
index a0bc07d9a8..5b0979eaa2 100644
--- a/web/packages/next/types/ipc.ts
+++ b/web/packages/next/types/ipc.ts
@@ -196,7 +196,7 @@ export interface Electron {
     // - ML

     /**
-     * Compute and return a CLIP embedding of the given image.
+     * Return a CLIP embedding of the given image.
      *
      * See: [Note: CLIP based magic search]
      *
@@ -207,7 +207,7 @@ export interface Electron {
     clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;

     /**
-     * Compute and return a CLIP embedding of the given image.
+     * Return a CLIP embedding of the given text.
      *
      * See: [Note: CLIP based magic search]
      *
@@ -217,6 +217,22 @@ export interface Electron {
      */
     clipTextEmbedding: (text: string) => Promise<Float32Array>;

+    /**
+     * Detect faces in the given image using YOLO.
+     *
+     * Both the input and output are opaque binary data whose internal structure
+     * is specific to our implementation and the model (YOLO) we use.
+     */
+    detectFaces: (input: Float32Array) => Promise<Float32Array>;
+
+    /**
+     * Return a MobileFaceNet embedding for the given face data.
+     *
+     * Both the input and output are opaque binary data whose internal structure
+     * is specific to our implementation and the model (MobileFaceNet) we use.
+     */
+    faceEmbedding: (input: Float32Array) => Promise<Float32Array>;
+
     // - File selection

     // TODO: Deprecated - use dialogs on the renderer process itself
diff --git a/web/packages/next/worker/comlink-worker.ts b/web/packages/next/worker/comlink-worker.ts
index 033c79fa8c..f082ac1145 100644
--- a/web/packages/next/worker/comlink-worker.ts
+++ b/web/packages/next/worker/comlink-worker.ts
@@ -35,6 +35,19 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
     }
 }

+// TODO(MR): Temporary method to forward auth tokens to workers
+const getAuthToken = () => {
+    // LS_KEYS.USER
+    const userJSONString = localStorage.getItem("user");
+    if (!userJSONString) return undefined;
+    const json: unknown = JSON.parse(userJSONString);
+    if (!json || typeof json != "object" || !("token" in json))
+        return undefined;
+    const token = json.token;
+    if (typeof token != "string") return undefined;
+    return token;
+};
+
 /**
  * A minimal set of utility functions that we expose to all workers that we
  * create.
@@ -44,8 +57,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
  */
 const workerBridge = {
     logToDisk,
+    getAuthToken,
     convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
         ensureElectron().convertToJPEG(inputFileData, filename),
+    detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
+    faceEmbedding: (input: Float32Array) =>
+        ensureElectron().faceEmbedding(input),
 };

 export type WorkerBridge = typeof workerBridge;
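Taken together, the bridge now lets any comlink worker reach the native ML
implementations (and the auth token) without touching the DOM or Electron
APIs directly. A usage sketch from inside a worker, mirroring what the face
services in this diff do (assuming `worker-bridge.ts`, not shown here, exposes
the comlink proxy that these services import):

    import { workerBridge } from "@/next/worker/worker-bridge";

    const exampleFromWorker = async (preprocessed: Float32Array) => {
        // Proxied to the main thread, then forwarded over IPC to the
        // Node.js layer (ml-face.ts).
        const yoloOutput = await workerBridge.detectFaces(preprocessed);
        // Used by e.g. putEmbedding to authenticate API calls from a worker.
        const token = await workerBridge.getAuthToken();
        return { yoloOutput, token };
    };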