diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts
index 641ce9963d..37ee0478e5 100644
--- a/desktop/src/main/ipc.ts
+++ b/desktop/src/main/ipc.ts
@@ -45,8 +45,9 @@ import { logout } from "./services/logout";
 import {
     computeCLIPImageEmbedding,
     computeCLIPTextEmbeddingIfAvailable,
-} from "./services/ml-clip";
-import { computeFaceEmbeddings, detectFaces } from "./services/ml-face";
+    computeFaceEmbeddings,
+    detectFaces,
+} from "./services/ml-utility";
 import {
     encryptionKey,
     lastShownChangelogVersion,
diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts
deleted file mode 100644
index cea1d667b5..0000000000
--- a/desktop/src/main/services/ml-clip.ts
+++ /dev/null
@@ -1,68 +0,0 @@
-/**
- * @file Compute CLIP embeddings for images and text.
- *
- * The embeddings are computed using ONNX runtime, with CLIP as the model.
- */
-
-import Tokenizer from "clip-bpe-js";
-import * as ort from "onnxruntime-node";
-import log from "../log";
-import { ensure, wait } from "../utils/common";
-import { makeCachedInferenceSession } from "./ml";
-
-const cachedCLIPImageSession = makeCachedInferenceSession(
-    "clip-image-vit-32-float32.onnx",
-    351468764 /* 335.2 MB */,
-);
-
-export const computeCLIPImageEmbedding = async (input: Float32Array) => {
-    const session = await cachedCLIPImageSession();
-    const t = Date.now();
-    const feeds = {
-        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
-    };
-    const results = await session.run(feeds);
-    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
-    /* Need these model specific casts to type the result */
-    return ensure(results.output).data as Float32Array;
-};
-
-const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
-);
-
-let _tokenizer: Tokenizer | undefined;
-const getTokenizer = () => {
-    if (!_tokenizer) _tokenizer = new Tokenizer();
-    return _tokenizer;
-};
-
-export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
-    const sessionOrSkip = await Promise.race([
-        cachedCLIPTextSession(),
-        // Wait for a tick to get the session promise to resolved the first time
-        // this code runs on each app start (and the model has been downloaded).
-        wait(0).then(() => 1),
-    ]);
-
-    // Don't wait for the download to complete.
-    if (typeof sessionOrSkip == "number") {
-        log.info(
-            "Ignoring CLIP text embedding request because model download is pending",
-        );
-        return undefined;
-    }
-
-    const session = sessionOrSkip;
-    const t = Date.now();
-    const tokenizer = getTokenizer();
-    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
-    const feeds = {
-        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
-    };
-
-    const results = await session.run(feeds);
-    log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
-    return ensure(results.output).data as Float32Array;
-};
diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts
deleted file mode 100644
index 33c09efaa2..0000000000
--- a/desktop/src/main/services/ml-face.ts
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * @file Various face recognition related tasks.
- *
- * - Face detection with the YOLO model.
- * - Face embedding with the MobileFaceNet model.
- *
- * The runtime used is ONNX.
- */
-
-import * as ort from "onnxruntime-node";
-import log from "../log";
-import { ensure } from "../utils/common";
-import { makeCachedInferenceSession } from "./ml";
-
-const cachedFaceDetectionSession = makeCachedInferenceSession(
-    "yolov5s_face_640_640_dynamic.onnx",
-    30762872 /* 29.3 MB */,
-);
-
-export const detectFaces = async (input: Float32Array) => {
-    const session = await cachedFaceDetectionSession();
-    const t = Date.now();
-    const feeds = {
-        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
-    };
-    const results = await session.run(feeds);
-    log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
-    return ensure(results.output).data;
-};
-
-const cachedFaceEmbeddingSession = makeCachedInferenceSession(
-    "mobilefacenet_opset15.onnx",
-    5286998 /* 5 MB */,
-);
-
-export const computeFaceEmbeddings = async (input: Float32Array) => {
-    // Dimension of each face (alias)
-    const mobileFaceNetFaceSize = 112;
-    // Smaller alias
-    const z = mobileFaceNetFaceSize;
-    // Size of each face's data in the batch
-    const n = Math.round(input.length / (z * z * 3));
-    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
-
-    const session = await cachedFaceEmbeddingSession();
-    const t = Date.now();
-    const feeds = { img_inputs: inputTensor };
-    const results = await session.run(feeds);
-    log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
-    /* Need these model specific casts to extract and type the result */
-    return (results.embeddings as unknown as Record<string, unknown>)
-        .cpuData as Float32Array;
-};
diff --git a/desktop/src/main/services/ml-utility.ts b/desktop/src/main/services/ml-utility.ts
new file mode 100644
index 0000000000..79d39edea4
--- /dev/null
+++ b/desktop/src/main/services/ml-utility.ts
@@ -0,0 +1,235 @@
+/**
+ * @file ML related tasks. This code runs in a utility process.
+ *
+ * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
+ * for various tasks are not shipped with the app but are downloaded on demand.
+ */
+
+import Tokenizer from "clip-bpe-js";
+import { app, net } from "electron/main";
+import { existsSync } from "fs";
+import fs from "node:fs/promises";
+import path from "node:path";
+import * as ort from "onnxruntime-node";
+import log from "../log";
+import { writeStream } from "../stream";
+import { ensure, wait } from "../utils/common";
+
+/**
+ * Return a function that can be used to trigger a download of the specified
+ * model, and the creation of an ONNX inference session initialized using it.
+ *
+ * Multiple parallel calls to the returned function are fine; it ensures that
+ * the model will be downloaded and the session created using it only once.
+ * All pending calls to it meanwhile will just await the same promise.
+ *
+ * And once the promise is resolved, the created ONNX inference session will be
+ * cached, so subsequent calls to the returned function will just reuse the same
+ * session.
+ *
+ * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
+ * actively trigger a download until the returned function is called.
+ *
+ * @param modelName The name of the model to download.
+ *
+ * @param modelByteSize The size in bytes that we expect the model to have. If
+ * the size of the downloaded model does not match the expected size, then we
+ * will redownload it.
+ *
+ * @returns A function. Calling that function returns a promise to an ONNX
+ * session.
+ */
+const makeCachedInferenceSession = (
+    modelName: string,
+    modelByteSize: number,
+) => {
+    let session: Promise<ort.InferenceSession> | undefined;
+
+    const download = () =>
+        modelPathDownloadingIfNeeded(modelName, modelByteSize);
+
+    const createSession = (modelPath: string) =>
+        createInferenceSession(modelPath);
+
+    const cachedInferenceSession = () => {
+        if (!session) session = download().then(createSession);
+        return session;
+    };
+
+    return cachedInferenceSession;
+};
+
+/**
+ * Download the model named {@link modelName} if we don't already have it.
+ *
+ * Also verify that the size of the model we get matches
+ * {@link expectedByteSize} (if not, redownload it).
+ *
+ * @returns the path to the model on the local machine.
+ */
+const modelPathDownloadingIfNeeded = async (
+    modelName: string,
+    expectedByteSize: number,
+) => {
+    const modelPath = modelSavePath(modelName);
+
+    if (!existsSync(modelPath)) {
+        log.info("CLIP image model not found, downloading");
+        await downloadModel(modelPath, modelName);
+    } else {
+        const size = (await fs.stat(modelPath)).size;
+        if (size !== expectedByteSize) {
+            log.error(
+                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
+            );
+            await downloadModel(modelPath, modelName);
+        }
+    }
+
+    return modelPath;
+};
+
+/** Return the path where the given {@link modelName} is meant to be saved */
+const modelSavePath = (modelName: string) =>
+    path.join(app.getPath("userData"), "models", modelName);
+
+const downloadModel = async (saveLocation: string, name: string) => {
+    // `mkdir -p` the directory where we want to save the model.
+    const saveDir = path.dirname(saveLocation);
+    await fs.mkdir(saveDir, { recursive: true });
+    // Download.
+    log.info(`Downloading ML model from ${name}`);
+    const url = `https://models.ente.io/${name}`;
+    const res = await net.fetch(url);
+    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
+    const body = res.body;
+    if (!body) throw new Error(`Received a null response for ${url}`);
+    // Save.
+    await writeStream(saveLocation, body);
+    log.info(`Downloaded CLIP model ${name}`);
+};
+
+/**
+ * Create an ONNX {@link InferenceSession} with some defaults.
+ */
+const createInferenceSession = async (modelPath: string) => {
+    return await ort.InferenceSession.create(modelPath, {
+        // Restrict the number of threads to 1.
+        intraOpNumThreads: 1,
+        // Be more conservative with RAM usage.
+        enableCpuMemArena: false,
+    });
+};
+
+const cachedCLIPImageSession = makeCachedInferenceSession(
+    "clip-image-vit-32-float32.onnx",
+    351468764 /* 335.2 MB */,
+);
+
+/**
+ * Compute CLIP embeddings for an image.
+ *
+ * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ */
+export const computeCLIPImageEmbedding = async (input: Float32Array) => {
+    const session = await cachedCLIPImageSession();
+    const t = Date.now();
+    const feeds = {
+        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
+    };
+    const results = await session.run(feeds);
+    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
+    /* Need these model specific casts to type the result */
+    return ensure(results.output).data as Float32Array;
+};
+
+const cachedCLIPTextSession = makeCachedInferenceSession(
+    "clip-text-vit-32-uint8.onnx",
+    64173509 /* 61.2 MB */,
+);
+
+let _tokenizer: Tokenizer | undefined;
+const getTokenizer = () => {
+    if (!_tokenizer) _tokenizer = new Tokenizer();
+    return _tokenizer;
+};
+
+/**
+ * Compute CLIP embeddings for a text snippet.
+ *
+ * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ */
+export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
+    const sessionOrSkip = await Promise.race([
+        cachedCLIPTextSession(),
+        // Wait for a tick to get the session promise to resolve the first time
+        // this code runs on each app start (and the model has been downloaded).
+        wait(0).then(() => 1),
+    ]);
+
+    // Don't wait for the download to complete.
+    if (typeof sessionOrSkip == "number") {
+        log.info(
+            "Ignoring CLIP text embedding request because model download is pending",
+        );
+        return undefined;
+    }
+
+    const session = sessionOrSkip;
+    const t = Date.now();
+    const tokenizer = getTokenizer();
+    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
+    const feeds = {
+        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
+    };
+
+    const results = await session.run(feeds);
+    log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
+    return ensure(results.output).data as Float32Array;
+};
+
+const cachedFaceDetectionSession = makeCachedInferenceSession(
+    "yolov5s_face_640_640_dynamic.onnx",
+    30762872 /* 29.3 MB */,
+);
+
+/**
+ * Face detection with the YOLO model and ONNX runtime.
+ */
+export const detectFaces = async (input: Float32Array) => {
+    const session = await cachedFaceDetectionSession();
+    const t = Date.now();
+    const feeds = {
+        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
+    };
+    const results = await session.run(feeds);
+    log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
+    return ensure(results.output).data;
+};
+
+const cachedFaceEmbeddingSession = makeCachedInferenceSession(
+    "mobilefacenet_opset15.onnx",
+    5286998 /* 5 MB */,
+);
+
+/**
+ * Face embedding with the MobileFaceNet model and ONNX runtime.
+ */
+export const computeFaceEmbeddings = async (input: Float32Array) => {
+    // Dimension of each face (alias)
+    const mobileFaceNetFaceSize = 112;
+    // Smaller alias
+    const z = mobileFaceNetFaceSize;
+    // Size of each face's data in the batch
+    const n = Math.round(input.length / (z * z * 3));
+    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
+
+    const session = await cachedFaceEmbeddingSession();
+    const t = Date.now();
+    const feeds = { img_inputs: inputTensor };
+    const results = await session.run(feeds);
+    log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
+    /* Need these model specific casts to extract and type the result */
+    return (results.embeddings as unknown as Record<string, unknown>)
+        .cpuData as Float32Array;
+};
diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts
index dcda511df6..5d8b0cf1e3 100644
--- a/desktop/src/main/services/ml.ts
+++ b/desktop/src/main/services/ml.ts
@@ -1,20 +1,7 @@
 /**
- * @file ML related functionality, generic layer.
- *
- * @see also `ml-clip.ts`, `ml-face.ts`.
- *
- * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
- * for various tasks are not shipped with the app but are downloaded on demand.
+ * @file ML related functionality. This code runs in the main process.
  */
 
-import { app, net } from "electron/main";
-import { existsSync } from "fs";
-import fs from "node:fs/promises";
-import path from "node:path";
-import * as ort from "onnxruntime-node";
-import log from "../log";
-import { writeStream } from "../stream";
-
 /**
  * Create a new ML session.
  *
@@ -53,15 +40,18 @@ import { writeStream } from "../stream";
  * to communicate. A pair of message ports is called a message channel. The cool
  * thing about these is that we can pass these ports themselves over IPC.
  *
+ * > One caveat here is that the message ports can only be passed using the
+ * > `postMessage` APIs, not the usual send/invoke APIs.
+ *
  * So we
  *
- * 1. Spawn a utility process.
- * 2. In the utility process create a message channel.
- * 3. Keep one port of the pair with us, and send the other over IPC to the
- *    _web worker_ that is coordinating the ML indexing on the web layer.
+ * 1. In the main process, create a message channel.
+ * 2. Spawn a utility process, and send one port of the pair to it.
+ * 3. Send the other port of the pair to the renderer.
  *
- * Thereafter, the utility process and web worker can directly talk to each
- * other!
+ * The renderer will forward that port to the web worker that is coordinating
+ * the ML indexing on the web layer. Thereafter, the utility process and web
+ * worker can directly talk to each other!
  *
  * Node.js utility process <-> Renderer web worker
  *
@@ -70,109 +60,3 @@ export const createMLSession = () => {
     // }: Promise => {
     throw new Error("Not implemented");
 };
-
-/**
- * Return a function that can be used to trigger a download of the specified
- * model, and the creating of an ONNX inference session initialized using it.
- *
- * Multiple parallel calls to the returned function are fine, it ensures that
- * the the model will be downloaded and the session created using it only once.
- * All pending calls to it meanwhile will just await on the same promise.
- *
- * And once the promise is resolved, the create ONNX inference session will be
- * cached, so subsequent calls to the returned function will just reuse the same
- * session.
- *
- * {@link makeCachedInferenceSession} can itself be called anytime, it doesn't
- * actively trigger a download until the returned function is called.
- *
- * @param modelName The name of the model to download.
- *
- * @param modelByteSize The size in bytes that we expect the model to have. If
- * the size of the downloaded model does not match the expected size, then we
- * will redownload it.
- *
- * @returns a function. calling that function returns a promise to an ONNX
- * session.
- */
-export const makeCachedInferenceSession = (
-    modelName: string,
-    modelByteSize: number,
-) => {
-    let session: Promise<ort.InferenceSession> | undefined;
-
-    const download = () =>
-        modelPathDownloadingIfNeeded(modelName, modelByteSize);
-
-    const createSession = (modelPath: string) =>
-        createInferenceSession(modelPath);
-
-    const cachedInferenceSession = () => {
-        if (!session) session = download().then(createSession);
-        return session;
-    };
-
-    return cachedInferenceSession;
-};
-
-/**
- * Download the model named {@link modelName} if we don't already have it.
- *
- * Also verify that the size of the model we get matches {@expectedByteSize} (if
- * not, redownload it).
- *
- * @returns the path to the model on the local machine.
- */
-const modelPathDownloadingIfNeeded = async (
-    modelName: string,
-    expectedByteSize: number,
-) => {
-    const modelPath = modelSavePath(modelName);
-
-    if (!existsSync(modelPath)) {
-        log.info("CLIP image model not found, downloading");
-        await downloadModel(modelPath, modelName);
-    } else {
-        const size = (await fs.stat(modelPath)).size;
-        if (size !== expectedByteSize) {
-            log.error(
-                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
-            );
-            await downloadModel(modelPath, modelName);
-        }
-    }
-
-    return modelPath;
-};
-
-/** Return the path where the given {@link modelName} is meant to be saved */
-const modelSavePath = (modelName: string) =>
-    path.join(app.getPath("userData"), "models", modelName);
-
-const downloadModel = async (saveLocation: string, name: string) => {
-    // `mkdir -p` the directory where we want to save the model.
-    const saveDir = path.dirname(saveLocation);
-    await fs.mkdir(saveDir, { recursive: true });
-    // Download.
-    log.info(`Downloading ML model from ${name}`);
-    const url = `https://models.ente.io/${name}`;
-    const res = await net.fetch(url);
-    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
-    const body = res.body;
-    if (!body) throw new Error(`Received an null response for ${url}`);
-    // Save.
-    await writeStream(saveLocation, body);
-    log.info(`Downloaded CLIP model ${name}`);
-};
-
-/**
- * Crete an ONNX {@link InferenceSession} with some defaults.
- */
-const createInferenceSession = async (modelPath: string) => {
-    return await ort.InferenceSession.create(modelPath, {
-        // Restrict the number of threads to 1.
-        intraOpNumThreads: 1,
-        // Be more conservative with RAM usage.
-        enableCpuMemArena: false,
-    });
-};
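
Note (not part of the patch): the rewritten `createMLSession` comment in `ml.ts` describes the port handoff, but the function itself still throws "Not implemented". A minimal sketch of how that wiring could look on the main process side, assuming a hypothetical `ml-worker.js` utility process module and an illustrative `"createMLSession"` IPC channel name (neither appears in this diff):

```ts
// Sketch only: the message channel handoff described in the createMLSession
// comment. "ml-worker.js" and the "createMLSession" channel are hypothetical.
import { BrowserWindow, MessageChannelMain, utilityProcess } from "electron/main";
import path from "node:path";

const createMLSessionSketch = (window: BrowserWindow) => {
    // 1. In the main process, create a message channel (a pair of ports).
    const { port1, port2 } = new MessageChannelMain();

    // 2. Spawn the utility process and send it one port of the pair. Ports
    //    can only be transferred via the postMessage APIs.
    const worker = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
    worker.postMessage("hereIsYourPort", [port1]);

    // 3. Send the other port to the renderer, which forwards it to the web
    //    worker coordinating the ML indexing. Thereafter the utility process
    //    and the web worker talk directly over this channel.
    window.webContents.postMessage("createMLSession", null, [port2]);
};
```

Inside the utility process itself, the transferred port would arrive via `process.parentPort.on("message", (e) => e.ports[0])`.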
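
Separately, the `Promise.race([cachedCLIPTextSession(), wait(0).then(() => 1)])` dance in `computeCLIPTextEmbeddingIfAvailable` is a way to peek at whether a promise has already settled without blocking on it. A generic sketch of the idiom, with an inline `wait` standing in for the helper imported from `../utils/common`:

```ts
// Sketch of the "peek without waiting" idiom used by
// computeCLIPTextEmbeddingIfAvailable.
const wait = (ms: number) =>
    new Promise<void>((resolve) => setTimeout(resolve, ms));

const peek = async <T>(promise: Promise<T>): Promise<T | undefined> =>
    // An already resolved promise wins the race (its .then runs as a
    // microtask), while a still pending one loses to the 0 ms timer, in
    // which case we report undefined instead of waiting for it.
    Promise.race([promise, wait(0).then(() => undefined)]);
```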