diff --git a/desktop/package.json b/desktop/package.json index 453bb931fa..c3da2c3591 100644 --- a/desktop/package.json +++ b/desktop/package.json @@ -28,6 +28,7 @@ "auto-launch": "^5.0", "chokidar": "^3.6", "clip-bpe-js": "^0.0.6", + "comlink": "^4.4.1", "compare-versions": "^6.1", "electron-log": "^5.1", "electron-store": "^8.2", diff --git a/desktop/src/main.ts b/desktop/src/main.ts index de969e3cf7..4ebe565bca 100644 --- a/desktop/src/main.ts +++ b/desktop/src/main.ts @@ -21,6 +21,7 @@ import { attachFSWatchIPCHandlers, attachIPCHandlers, attachLogoutIPCHandler, + attachMainWindowIPCHandlers, } from "./main/ipc"; import log, { initLogging } from "./main/log"; import { createApplicationMenu, createTrayContextMenu } from "./main/menu"; @@ -121,6 +122,7 @@ const main = () => { // Setup IPC and streams. const watcher = createWatcher(mainWindow); attachIPCHandlers(); + attachMainWindowIPCHandlers(mainWindow); attachFSWatchIPCHandlers(watcher); attachLogoutIPCHandler(watcher); registerStreamProtocol(); diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index 641ce9963d..6c4020d6ee 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -9,6 +9,7 @@ */ import type { FSWatcher } from "chokidar"; +import type { BrowserWindow } from "electron"; import { ipcMain } from "electron/main"; import type { CollectionMapping, @@ -42,11 +43,7 @@ import { } from "./services/fs"; import { convertToJPEG, generateImageThumbnail } from "./services/image"; import { logout } from "./services/logout"; -import { - computeCLIPImageEmbedding, - computeCLIPTextEmbeddingIfAvailable, -} from "./services/ml-clip"; -import { computeFaceEmbeddings, detectFaces } from "./services/ml-face"; +import { createMLWorker } from "./services/ml"; import { encryptionKey, lastShownChangelogVersion, @@ -184,24 +181,6 @@ export const attachIPCHandlers = () => { ) => ffmpegExec(command, dataOrPathOrZipItem, outputFileExtension), ); - // - ML - - ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) => - computeCLIPImageEmbedding(input), - ); - - ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) => - computeCLIPTextEmbeddingIfAvailable(text), - ); - - ipcMain.handle("detectFaces", (_, input: Float32Array) => - detectFaces(input), - ); - - ipcMain.handle("computeFaceEmbeddings", (_, input: Float32Array) => - computeFaceEmbeddings(input), - ); - // - Upload ipcMain.handle("listZipItems", (_, zipPath: string) => @@ -231,6 +210,16 @@ export const attachIPCHandlers = () => { ipcMain.handle("clearPendingUploads", () => clearPendingUploads()); }; +/** + * A subset of {@link attachIPCHandlers} for functions that need a reference to + * the main window to do their thing. + */ +export const attachMainWindowIPCHandlers = (mainWindow: BrowserWindow) => { + // - ML + + ipcMain.on("createMLWorker", () => createMLWorker(mainWindow)); +}; + /** * Sibling of {@link attachIPCHandlers} that attaches handlers specific to the * watch folder functionality. diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts deleted file mode 100644 index cea1d667b5..0000000000 --- a/desktop/src/main/services/ml-clip.ts +++ /dev/null @@ -1,68 +0,0 @@ -/** - * @file Compute CLIP embeddings for images and text. - * - * The embeddings are computed using ONNX runtime, with CLIP as the model. 
- */ - -import Tokenizer from "clip-bpe-js"; -import * as ort from "onnxruntime-node"; -import log from "../log"; -import { ensure, wait } from "../utils/common"; -import { makeCachedInferenceSession } from "./ml"; - -const cachedCLIPImageSession = makeCachedInferenceSession( - "clip-image-vit-32-float32.onnx", - 351468764 /* 335.2 MB */, -); - -export const computeCLIPImageEmbedding = async (input: Float32Array) => { - const session = await cachedCLIPImageSession(); - const t = Date.now(); - const feeds = { - input: new ort.Tensor("float32", input, [1, 3, 224, 224]), - }; - const results = await session.run(feeds); - log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`); - /* Need these model specific casts to type the result */ - return ensure(results.output).data as Float32Array; -}; - -const cachedCLIPTextSession = makeCachedInferenceSession( - "clip-text-vit-32-uint8.onnx", - 64173509 /* 61.2 MB */, -); - -let _tokenizer: Tokenizer | undefined; -const getTokenizer = () => { - if (!_tokenizer) _tokenizer = new Tokenizer(); - return _tokenizer; -}; - -export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { - const sessionOrSkip = await Promise.race([ - cachedCLIPTextSession(), - // Wait for a tick to get the session promise to resolved the first time - // this code runs on each app start (and the model has been downloaded). - wait(0).then(() => 1), - ]); - - // Don't wait for the download to complete. - if (typeof sessionOrSkip == "number") { - log.info( - "Ignoring CLIP text embedding request because model download is pending", - ); - return undefined; - } - - const session = sessionOrSkip; - const t = Date.now(); - const tokenizer = getTokenizer(); - const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); - const feeds = { - input: new ort.Tensor("int32", tokenizedText, [1, 77]), - }; - - const results = await session.run(feeds); - log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`); - return ensure(results.output).data as Float32Array; -}; diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts deleted file mode 100644 index 33c09efaa2..0000000000 --- a/desktop/src/main/services/ml-face.ts +++ /dev/null @@ -1,53 +0,0 @@ -/** - * @file Various face recognition related tasks. - * - * - Face detection with the YOLO model. - * - Face embedding with the MobileFaceNet model. - * - * The runtime used is ONNX. 
- */ - -import * as ort from "onnxruntime-node"; -import log from "../log"; -import { ensure } from "../utils/common"; -import { makeCachedInferenceSession } from "./ml"; - -const cachedFaceDetectionSession = makeCachedInferenceSession( - "yolov5s_face_640_640_dynamic.onnx", - 30762872 /* 29.3 MB */, -); - -export const detectFaces = async (input: Float32Array) => { - const session = await cachedFaceDetectionSession(); - const t = Date.now(); - const feeds = { - input: new ort.Tensor("float32", input, [1, 3, 640, 640]), - }; - const results = await session.run(feeds); - log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`); - return ensure(results.output).data; -}; - -const cachedFaceEmbeddingSession = makeCachedInferenceSession( - "mobilefacenet_opset15.onnx", - 5286998 /* 5 MB */, -); - -export const computeFaceEmbeddings = async (input: Float32Array) => { - // Dimension of each face (alias) - const mobileFaceNetFaceSize = 112; - // Smaller alias - const z = mobileFaceNetFaceSize; - // Size of each face's data in the batch - const n = Math.round(input.length / (z * z * 3)); - const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]); - - const session = await cachedFaceEmbeddingSession(); - const t = Date.now(); - const feeds = { img_inputs: inputTensor }; - const results = await session.run(feeds); - log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`); - /* Need these model specific casts to extract and type the result */ - return (results.embeddings as unknown as Record) - .cpuData as Float32Array; -}; diff --git a/desktop/src/main/services/ml-worker.ts b/desktop/src/main/services/ml-worker.ts new file mode 100644 index 0000000000..f4b9221f64 --- /dev/null +++ b/desktop/src/main/services/ml-worker.ts @@ -0,0 +1,315 @@ +/** + * @file ML related tasks. This code runs in a utility process. + * + * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models + * for various tasks are not shipped with the app but are downloaded on demand. + */ + +// See [Note: Using Electron APIs in UtilityProcess] about what we can and +// cannot import. + +import Tokenizer from "clip-bpe-js"; +import { expose } from "comlink"; +import { net } from "electron/main"; +import { existsSync } from "fs"; +import fs from "node:fs/promises"; +import path from "node:path"; +import * as ort from "onnxruntime-node"; +import { messagePortMainEndpoint } from "../utils/comlink"; +import { ensure, wait } from "../utils/common"; +import { writeStream } from "../utils/stream"; + +/** + * We cannot do + * + * import log from "../log"; + * + * because that requires the Electron APIs that are not available to a utility + * process (See: [Note: Using Electron APIs in UtilityProcess]). But even if + * that were to work, logging will still be problematic since we'd try opening + * the log file from two different Node.js processes (this one, and the main + * one), and I didn't find any indication in the electron-log repository that + * the log file's integrity would be maintained in such cases. + * + * So instead we create this proxy log object that uses `process.parentPort` to + * transport the logs over to the main process. + */ +const log = { + /** + * Unlike the real {@link log.error}, this accepts only the first string + * argument, not the second optional error one. 
+ */ + errorString: (s: string) => mainProcess("log.errorString", s), + info: (...ms: unknown[]) => mainProcess("log.info", ms), + /** + * Unlike the real {@link log.debug}, this is (a) eagerly evaluated, and (b) + * accepts only strings. + */ + debugString: (s: string) => mainProcess("log.debugString", s), +}; + +/** + * Send a message to the main process using a barebones RPC protocol. + */ +const mainProcess = (method: string, param: unknown) => + process.parentPort.postMessage({ method, p: param }); + +log.debugString(`Started ML worker process`); + +process.parentPort.once("message", (e) => { + // Initialize ourselves with the data we got from our parent. + parseInitData(e.data); + // Expose an instance of `ElectronMLWorker` on the port we got from our + // parent. + expose( + { + computeCLIPImageEmbedding, + computeCLIPTextEmbeddingIfAvailable, + detectFaces, + computeFaceEmbeddings, + }, + messagePortMainEndpoint(ensure(e.ports[0])), + ); +}); + +/** + * We cannot access Electron's {@link app} object within a utility process, so + * we pass the value of `app.getPath("userData")` during initialization, and it + * can be subsequently retrieved from here. + */ +let _userDataPath: string | undefined; + +/** Equivalent to app.getPath("userData") */ +const userDataPath = () => ensure(_userDataPath); + +const parseInitData = (data: unknown) => { + if ( + data && + typeof data == "object" && + "userDataPath" in data && + typeof data.userDataPath == "string" + ) { + _userDataPath = data.userDataPath; + } else { + log.errorString("Unparseable initialization data"); + } +}; + +/** + * Return a function that can be used to trigger a download of the specified + * model, and the creating of an ONNX inference session initialized using it. + * + * Multiple parallel calls to the returned function are fine, it ensures that + * the the model will be downloaded and the session created using it only once. + * All pending calls to it meanwhile will just await on the same promise. + * + * And once the promise is resolved, the create ONNX inference session will be + * cached, so subsequent calls to the returned function will just reuse the same + * session. + * + * {@link makeCachedInferenceSession} can itself be called anytime, it doesn't + * actively trigger a download until the returned function is called. + * + * @param modelName The name of the model to download. + * + * @param modelByteSize The size in bytes that we expect the model to have. If + * the size of the downloaded model does not match the expected size, then we + * will redownload it. + * + * @returns a function. calling that function returns a promise to an ONNX + * session. + */ +const makeCachedInferenceSession = ( + modelName: string, + modelByteSize: number, +) => { + let session: Promise | undefined; + + const download = () => + modelPathDownloadingIfNeeded(modelName, modelByteSize); + + const createSession = (modelPath: string) => + createInferenceSession(modelPath); + + const cachedInferenceSession = () => { + if (!session) session = download().then(createSession); + return session; + }; + + return cachedInferenceSession; +}; + +/** + * Download the model named {@link modelName} if we don't already have it. + * + * Also verify that the size of the model we get matches {@expectedByteSize} (if + * not, redownload it). + * + * @returns the path to the model on the local machine. 
+ */ +const modelPathDownloadingIfNeeded = async ( + modelName: string, + expectedByteSize: number, +) => { + const modelPath = modelSavePath(modelName); + + if (!existsSync(modelPath)) { + log.info("CLIP image model not found, downloading"); + await downloadModel(modelPath, modelName); + } else { + const size = (await fs.stat(modelPath)).size; + if (size !== expectedByteSize) { + log.errorString( + `The size ${size} of model ${modelName} does not match the expected size, downloading again`, + ); + await downloadModel(modelPath, modelName); + } + } + + return modelPath; +}; + +/** Return the path where the given {@link modelName} is meant to be saved */ +const modelSavePath = (modelName: string) => + path.join(userDataPath(), "models", modelName); + +const downloadModel = async (saveLocation: string, name: string) => { + // `mkdir -p` the directory where we want to save the model. + const saveDir = path.dirname(saveLocation); + await fs.mkdir(saveDir, { recursive: true }); + // Download. + log.info(`Downloading ML model from ${name}`); + const url = `https://models.ente.io/${name}`; + const res = await net.fetch(url); + if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); + const body = res.body; + if (!body) throw new Error(`Received an null response for ${url}`); + // Save. + await writeStream(saveLocation, body); + log.info(`Downloaded CLIP model ${name}`); +}; + +/** + * Create an ONNX {@link InferenceSession} with some defaults. + */ +const createInferenceSession = async (modelPath: string) => { + return await ort.InferenceSession.create(modelPath, { + // Restrict the number of threads to 1. + intraOpNumThreads: 1, + // Be more conservative with RAM usage. + enableCpuMemArena: false, + }); +}; + +const cachedCLIPImageSession = makeCachedInferenceSession( + "clip-image-vit-32-float32.onnx", + 351468764 /* 335.2 MB */, +); + +/** + * Compute CLIP embeddings for an image. + * + * The embeddings are computed using ONNX runtime, with CLIP as the model. + */ +export const computeCLIPImageEmbedding = async (input: Float32Array) => { + const session = await cachedCLIPImageSession(); + const feeds = { + input: new ort.Tensor("float32", input, [1, 3, 224, 224]), + }; + const t = Date.now(); + const results = await session.run(feeds); + log.debugString(`ONNX/CLIP image embedding took ${Date.now() - t} ms`); + /* Need these model specific casts to type the result */ + return ensure(results.output).data as Float32Array; +}; + +const cachedCLIPTextSession = makeCachedInferenceSession( + "clip-text-vit-32-uint8.onnx", + 64173509 /* 61.2 MB */, +); + +let _tokenizer: Tokenizer | undefined; +const getTokenizer = () => { + if (!_tokenizer) _tokenizer = new Tokenizer(); + return _tokenizer; +}; + +/** + * Compute CLIP embeddings for an text snippet. + * + * The embeddings are computed using ONNX runtime, with CLIP as the model. + */ +export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { + const sessionOrSkip = await Promise.race([ + cachedCLIPTextSession(), + // Wait for a tick to get the session promise to resolved the first time + // this code runs on each app start (and the model has been downloaded). + wait(0).then(() => 1), + ]); + + // Don't wait for the download to complete. 
+ if (typeof sessionOrSkip == "number") { + log.info( + "Ignoring CLIP text embedding request because model download is pending", + ); + return undefined; + } + + const session = sessionOrSkip; + const tokenizer = getTokenizer(); + const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); + const feeds = { + input: new ort.Tensor("int32", tokenizedText, [1, 77]), + }; + + const t = Date.now(); + const results = await session.run(feeds); + log.debugString(`ONNX/CLIP text embedding took ${Date.now() - t} ms`); + return ensure(results.output).data as Float32Array; +}; + +const cachedFaceDetectionSession = makeCachedInferenceSession( + "yolov5s_face_640_640_dynamic.onnx", + 30762872 /* 29.3 MB */, +); + +/** + * Face detection with the YOLO model and ONNX runtime. + */ +export const detectFaces = async (input: Float32Array) => { + const session = await cachedFaceDetectionSession(); + const feeds = { + input: new ort.Tensor("float32", input, [1, 3, 640, 640]), + }; + const t = Date.now(); + const results = await session.run(feeds); + log.debugString(`ONNX/YOLO face detection took ${Date.now() - t} ms`); + return ensure(results.output).data; +}; + +const cachedFaceEmbeddingSession = makeCachedInferenceSession( + "mobilefacenet_opset15.onnx", + 5286998 /* 5 MB */, +); + +/** + * Face embedding with the MobileFaceNet model and ONNX runtime. + */ +export const computeFaceEmbeddings = async (input: Float32Array) => { + // Dimension of each face (alias) + const mobileFaceNetFaceSize = 112; + // Smaller alias + const z = mobileFaceNetFaceSize; + // Size of each face's data in the batch + const n = Math.round(input.length / (z * z * 3)); + const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]); + + const session = await cachedFaceEmbeddingSession(); + const feeds = { img_inputs: inputTensor }; + const t = Date.now(); + const results = await session.run(feeds); + log.debugString(`ONNX/MFNT face embedding took ${Date.now() - t} ms`); + /* Need these model specific casts to extract and type the result */ + return (results.embeddings as unknown as Record) + .cpuData as Float32Array; +}; diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts index 55bb8d79c2..cc1ae5764c 100644 --- a/desktop/src/main/services/ml.ts +++ b/desktop/src/main/services/ml.ts @@ -1,126 +1,147 @@ /** - * @file ML related functionality, generic layer. - * - * @see also `ml-clip.ts`, `ml-face.ts`. - * - * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models - * for various tasks are not shipped with the app but are downloaded on demand. - * - * The primary reason for doing these tasks in the Node.js layer is so that we - * can use the binary ONNX runtime which is 10-20x faster than the WASM based - * web one. + * @file ML related functionality. This code runs in the main process. */ -import { app, net } from "electron/main"; -import { existsSync } from "fs"; -import fs from "node:fs/promises"; +import { + MessageChannelMain, + type BrowserWindow, + type UtilityProcess, +} from "electron"; +import { app, utilityProcess } from "electron/main"; import path from "node:path"; -import * as ort from "onnxruntime-node"; import log from "../log"; -import { writeStream } from "../stream"; + +/** The active ML worker (utility) process, if any. */ +let _child: UtilityProcess | undefined; /** - * Return a function that can be used to trigger a download of the specified - * model, and the creating of an ONNX inference session initialized using it. 
+ * Create a new ML worker process, terminating the older ones (if any). * - * Multiple parallel calls to the returned function are fine, it ensures that - * the the model will be downloaded and the session created using it only once. - * All pending calls to it meanwhile will just await on the same promise. + * [Note: ML IPC] * - * And once the promise is resolved, the create ONNX inference session will be - * cached, so subsequent calls to the returned function will just reuse the same - * session. + * The primary reason for doing ML tasks in the Node.js layer is so that we can + * use the binary ONNX runtime, which is 10-20x faster than the WASM one that + * can be used directly on the web layer. * - * {@link makeCachedInferenceSession} can itself be called anytime, it doesn't - * actively trigger a download until the returned function is called. + * For this to work, the main and renderer process need to communicate with each + * other. Further, in the web layer the ML indexing runs in a web worker (so as + * to not get in the way of the main thread). So the communication has 2 hops: * - * @param modelName The name of the model to download. + * Node.js main <-> Renderer main <-> Renderer web worker * - * @param modelByteSize The size in bytes that we expect the model to have. If - * the size of the downloaded model does not match the expected size, then we - * will redownload it. + * This naive way works, but has a problem. The Node.js main process is in the + * code path for delivering user events to the renderer process. The ML tasks we + * do take in the order of 100-300 ms (possibly more) for each individual + * inference. Thus, the Node.js main process is busy for those 100-300 ms, and + * does not forward events to the renderer, causing the UI to jitter. * - * @returns a function. calling that function returns a promise to an ONNX - * session. + * The solution for this is to spawn an Electron UtilityProcess, which we can + * think of a regular Node.js child process. This frees up the Node.js main + * process, and would remove the jitter. + * https://www.electronjs.org/docs/latest/tutorial/process-model + * + * It would seem that this introduces another hop in our IPC + * + * Node.js utility process <-> Node.js main <-> ... + * + * but here we can use the special bit about Electron utility processes that + * separates them from regular Node.js child processes: their support for + * message ports. https://www.electronjs.org/docs/latest/tutorial/message-ports + * + * As a brief summary, a MessagePort is a web feature that allows two contexts + * to communicate. A pair of message ports is called a message channel. The cool + * thing about these is that we can pass these ports themselves over IPC. + * + * > One caveat here is that the message ports can only be passed using the + * > `postMessage` APIs, not the usual send/invoke APIs. + * + * So we + * + * 1. In the utility process create a message channel. + * 2. Spawn a utility process, and send one port of the pair to it. + * 3. Send the other port of the pair to the renderer. + * + * The renderer will forward that port to the web worker that is coordinating + * the ML indexing on the web layer. Thereafter, the utility process and web + * worker can directly talk to each other! + * + * Node.js utility process <-> Renderer web worker + * + * The RPC protocol is handled using comlink on both ends. The port itself needs + * to be relayed using `postMessage`. 
*/ -export const makeCachedInferenceSession = ( - modelName: string, - modelByteSize: number, -) => { - let session: Promise | undefined; - - const download = () => - modelPathDownloadingIfNeeded(modelName, modelByteSize); - - const createSession = (modelPath: string) => - createInferenceSession(modelPath); - - const cachedInferenceSession = () => { - if (!session) session = download().then(createSession); - return session; - }; - - return cachedInferenceSession; -}; - -/** - * Download the model named {@link modelName} if we don't already have it. - * - * Also verify that the size of the model we get matches {@expectedByteSize} (if - * not, redownload it). - * - * @returns the path to the model on the local machine. - */ -const modelPathDownloadingIfNeeded = async ( - modelName: string, - expectedByteSize: number, -) => { - const modelPath = modelSavePath(modelName); - - if (!existsSync(modelPath)) { - log.info("CLIP image model not found, downloading"); - await downloadModel(modelPath, modelName); - } else { - const size = (await fs.stat(modelPath)).size; - if (size !== expectedByteSize) { - log.error( - `The size ${size} of model ${modelName} does not match the expected size, downloading again`, - ); - await downloadModel(modelPath, modelName); - } +export const createMLWorker = (window: BrowserWindow) => { + if (_child) { + log.debug(() => "Terminating previous ML worker process"); + _child.kill(); + _child = undefined; } - return modelPath; -}; + const { port1, port2 } = new MessageChannelMain(); -/** Return the path where the given {@link modelName} is meant to be saved */ -const modelSavePath = (modelName: string) => - path.join(app.getPath("userData"), "models", modelName); + const child = utilityProcess.fork(path.join(__dirname, "ml-worker.js")); + const userDataPath = app.getPath("userData"); + child.postMessage({ userDataPath }, [port1]); -const downloadModel = async (saveLocation: string, name: string) => { - // `mkdir -p` the directory where we want to save the model. - const saveDir = path.dirname(saveLocation); - await fs.mkdir(saveDir, { recursive: true }); - // Download. - log.info(`Downloading ML model from ${name}`); - const url = `https://models.ente.io/${name}`; - const res = await net.fetch(url); - if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); - const body = res.body; - if (!body) throw new Error(`Received an null response for ${url}`); - // Save. - await writeStream(saveLocation, body); - log.info(`Downloaded CLIP model ${name}`); + window.webContents.postMessage("createMLWorker/port", undefined, [port2]); + + handleMessagesFromUtilityProcess(child); + + _child = child; }; /** - * Crete an ONNX {@link InferenceSession} with some defaults. + * Handle messages posted from the utility process. + * + * [Note: Using Electron APIs in UtilityProcess] + * + * Only a small subset of the Electron APIs are available to a UtilityProcess. + * As of writing (Jul 2024, Electron 30), only the following are available: + * + * - net + * - systemPreferences + * + * In particular, `app` is not available. + * + * We structure our code so that it doesn't need anything apart from `net`. + * + * For the other cases, + * + * - Additional parameters to the utility process are passed alongwith the + * initial message where we provide it the message port. + * + * - When we need to communicate from the utility process to the main process, + * we use the `parentPort` in the utility process. 
*/ -const createInferenceSession = async (modelPath: string) => { - return await ort.InferenceSession.create(modelPath, { - // Restrict the number of threads to 1. - intraOpNumThreads: 1, - // Be more conservative with RAM usage. - enableCpuMemArena: false, +const handleMessagesFromUtilityProcess = (child: UtilityProcess) => { + const logTag = "[ml-worker]"; + child.on("message", (m: unknown) => { + if (m && typeof m == "object" && "method" in m && "p" in m) { + const p = m.p; + switch (m.method) { + case "log.errorString": + if (typeof p == "string") { + log.error(`${logTag} ${p}`); + return; + } + break; + case "log.info": + if (Array.isArray(p)) { + // Need to cast from any[] to unknown[] + log.info(logTag, ...(p as unknown[])); + return; + } + break; + case "log.debugString": + if (typeof p == "string") { + log.debug(() => `${logTag} ${p}`); + return; + } + break; + default: + break; + } + } + log.info("Ignoring unknown message from ML worker", m); }); }; diff --git a/desktop/src/main/stream.ts b/desktop/src/main/stream.ts index 749c94f491..d32eecc627 100644 --- a/desktop/src/main/stream.ts +++ b/desktop/src/main/stream.ts @@ -3,7 +3,6 @@ */ import { net, protocol } from "electron/main"; import { randomUUID } from "node:crypto"; -import { createWriteStream, existsSync } from "node:fs"; import fs from "node:fs/promises"; import { Readable } from "node:stream"; import { ReadableStream } from "node:stream/web"; @@ -12,6 +11,7 @@ import log from "./log"; import { ffmpegConvertToMP4 } from "./services/ffmpeg"; import { markClosableZip, openZip } from "./services/zip"; import { ensure } from "./utils/common"; +import { writeStream } from "./utils/stream"; import { deleteTempFile, deleteTempFileIgnoringErrors, @@ -142,6 +142,7 @@ const handleReadZip = async (zipPath: string, entryName: string) => { // https://github.com/antelle/node-stream-zip/blob/master/node_stream_zip.js const modifiedMs = entry.time; + // @ts-expect-error [Note: Node and web stream type mismatch] return new Response(webReadableStream, { headers: { // We don't know the exact type, but it doesn't really matter, just @@ -159,39 +160,6 @@ const handleWrite = async (path: string, request: Request) => { return new Response("", { status: 200 }); }; -/** - * Write a (web) ReadableStream to a file at the given {@link filePath}. - * - * The returned promise resolves when the write completes. - * - * @param filePath The local file system path where the file should be written. - * - * @param readableStream A web - * [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream). - */ -export const writeStream = (filePath: string, readableStream: ReadableStream) => - writeNodeStream(filePath, Readable.fromWeb(readableStream)); - -const writeNodeStream = async (filePath: string, fileStream: Readable) => { - const writeable = createWriteStream(filePath); - - fileStream.on("error", (err) => { - writeable.destroy(err); // Close the writable stream with an error - }); - - fileStream.pipe(writeable); - - await new Promise((resolve, reject) => { - writeable.on("finish", resolve); - writeable.on("error", (err) => { - if (existsSync(filePath)) { - void fs.unlink(filePath); - } - reject(err); - }); - }); -}; - /** * A map from token to file paths for convert-to-mp4 requests that we have * received. 
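
Stepping back to the worker-to-main logging RPC in ml-worker.ts and ml.ts above: the { method, p } envelope that handleMessagesFromUtilityProcess decodes is produced by the `mainProcess` helper in the utility process. A condensed TypeScript sketch pairing the two halves already present in this diff (names such as `child` and `log` refer to the same variables used in ml.ts; this is illustrative only, not new code in the PR):

    // Utility process side (ml-worker.ts): post a { method, p } envelope to the
    // parent over process.parentPort.
    const mainProcess = (method: string, param: unknown) =>
        process.parentPort.postMessage({ method, p: param });
    mainProcess("log.info", ["ONNX session ready"]);

    // Main process side (ml.ts): decode the envelope and forward it to the real
    // electron-log logger, tagged so the worker's lines are distinguishable.
    child.on("message", (m: unknown) => {
        if (m && typeof m == "object" && "method" in m && "p" in m) {
            if (m.method == "log.info" && Array.isArray(m.p))
                log.info("[ml-worker]", ...(m.p as unknown[]));
        }
    });
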
diff --git a/desktop/src/main/utils/comlink.ts b/desktop/src/main/utils/comlink.ts new file mode 100644 index 0000000000..d2006e795b --- /dev/null +++ b/desktop/src/main/utils/comlink.ts @@ -0,0 +1,42 @@ +import type { Endpoint } from "comlink"; +import type { MessagePortMain } from "electron"; + +/** + * An adaptation of the `nodeEndpoint` function from comlink suitable for use in + * TypeScript with an Electron utility process. + * + * This is an adaption of the following function from comlink: + * https://github.com/GoogleChromeLabs/comlink/blob/main/src/node-adapter.ts + * + * It has been modified (somewhat hackily) to be useful with an Electron + * MessagePortMain instead of a Node.js worker_thread. Only things that we + * currently need have been made to work as you can see by the abundant type + * casts. Caveat emptor. + */ +export const messagePortMainEndpoint = (mp: MessagePortMain): Endpoint => { + type NL = EventListenerOrEventListenerObject; + type EL = (data: Electron.MessageEvent) => void; + const listeners = new WeakMap(); + return { + postMessage: (message, transfer) => { + mp.postMessage(message, transfer as unknown as MessagePortMain[]); + }, + addEventListener: (_, eh) => { + const l: EL = (data) => + "handleEvent" in eh + ? eh.handleEvent({ data } as MessageEvent) + : eh(data as unknown as MessageEvent); + mp.on("message", (data) => { + l(data); + }); + listeners.set(eh, l); + }, + removeEventListener: (_, eh) => { + const l = listeners.get(eh); + if (!l) return; + mp.off("message", l); + listeners.delete(eh); + }, + start: mp.start.bind(mp), + }; +}; diff --git a/desktop/src/main/utils/stream.ts b/desktop/src/main/utils/stream.ts new file mode 100644 index 0000000000..f5a98de0f7 --- /dev/null +++ b/desktop/src/main/utils/stream.ts @@ -0,0 +1,39 @@ +import { createWriteStream, existsSync } from "node:fs"; +import fs from "node:fs/promises"; +import { Readable } from "node:stream"; + +/** + * Write a (web) ReadableStream to a file at the given {@link filePath}. + * + * The returned promise resolves when the write completes. + * + * @param filePath The local file system path where the file should be written. + * + * @param readableStream A web + * [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream). + * + */ +export const writeStream = ( + filePath: string, + readableStream: unknown /*ReadableStream*/, // @ts-expect-error [Note: Node and web stream type mismatch] +) => writeNodeStream(filePath, Readable.fromWeb(readableStream)); + +const writeNodeStream = async (filePath: string, fileStream: Readable) => { + const writeable = createWriteStream(filePath); + + fileStream.on("error", (err) => { + writeable.destroy(err); // Close the writable stream with an error + }); + + fileStream.pipe(writeable); + + await new Promise((resolve, reject) => { + writeable.on("finish", resolve); + writeable.on("error", (err) => { + if (existsSync(filePath)) { + void fs.unlink(filePath); + } + reject(err); + }); + }); +}; diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index f2366aa63d..3058a6376f 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -36,6 +36,28 @@ * - [main] desktop/src/main/ipc.ts contains impl */ +// This code runs in the (isolated) web layer. Contrary to the impression given +// by the Electron docs (as of 2024), the window object is actually available to +// the preload script, and it is necessary for legitimate uses too. 
+// +// > The isolated world is connected to the DOM just the same is the main world, +// > it is just the JS contexts that are separated. +// > +// > https://github.com/electron/electron/issues/27024#issuecomment-745618327 +// +// Adding this reference here tells TypeScript that DOM typings (in particular, +// window) should be introduced in the ambient scope. +// +// [Note: Node and web stream type mismatch] +// +// Unfortunately, adding this reference causes the ReadableStream typings to +// break since lib.dom.d.ts adds its own incompatible definitions of +// ReadableStream to the global scope. +// +// https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/68407 + +/// + import { contextBridge, ipcRenderer, webUtils } from "electron/renderer"; // While we can't import other code, we can import types since they're just @@ -48,6 +70,19 @@ import type { ZipItem, } from "./types/ipc"; +// - Infrastructure + +// We need to wait until the renderer is ready before sending ports via +// postMessage, and this promise comes handy in such cases. We create the +// promise at the top level so that it is guaranteed to be registered before the +// load event is fired. +// +// See: https://www.electronjs.org/docs/latest/tutorial/message-ports + +const windowLoaded = new Promise((resolve) => { + window.onload = resolve; +}); + // - General const appVersion = () => ipcRenderer.invoke("appVersion"); @@ -163,17 +198,15 @@ const ffmpegExec = ( // - ML -const computeCLIPImageEmbedding = (input: Float32Array) => - ipcRenderer.invoke("computeCLIPImageEmbedding", input); - -const computeCLIPTextEmbeddingIfAvailable = (text: string) => - ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text); - -const detectFaces = (input: Float32Array) => - ipcRenderer.invoke("detectFaces", input); - -const computeFaceEmbeddings = (input: Float32Array) => - ipcRenderer.invoke("computeFaceEmbeddings", input); +const createMLWorker = () => { + ipcRenderer.send("createMLWorker"); + ipcRenderer.on("createMLWorker/port", (event) => { + void windowLoaded.then(() => { + // "*"" is the origin to send to. + window.postMessage("createMLWorker/port", "*", event.ports); + }); + }); +}; // - Watch @@ -281,8 +314,11 @@ const clearPendingUploads = () => ipcRenderer.invoke("clearPendingUploads"); * operation when it happens across threads. * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects * - * In our case though, we're not dealing with threads but separate processes. So - * the ArrayBuffer will be copied: + * In our case though, we're not dealing with threads but separate processes. + * Electron currently only supports transferring MessagePorts: + * https://github.com/electron/electron/issues/34905 + * + * So the ArrayBuffer will be copied: * * > "parameters, errors and return values are **copied** when they're sent over * > the bridge". 
@@ -339,10 +375,7 @@ contextBridge.exposeInMainWorld("electron", { // - ML - computeCLIPImageEmbedding, - computeCLIPTextEmbeddingIfAvailable, - detectFaces, - computeFaceEmbeddings, + createMLWorker, // - Watch diff --git a/desktop/yarn.lock b/desktop/yarn.lock index 5feaf65f6f..afbe850a91 100644 --- a/desktop/yarn.lock +++ b/desktop/yarn.lock @@ -968,6 +968,11 @@ combined-stream@^1.0.8: dependencies: delayed-stream "~1.0.0" +comlink@^4.4.1: + version "4.4.1" + resolved "https://registry.yarnpkg.com/comlink/-/comlink-4.4.1.tgz#e568b8e86410b809e8600eb2cf40c189371ef981" + integrity sha512-+1dlx0aY5Jo1vHy/tSsIGpSkN4tS9rZSW8FIhG0JH/crs9wwweswIo/POr451r7bZww3hFbPAKnTpimzL/mm4Q== + commander@^5.0.0: version "5.1.0" resolved "https://registry.yarnpkg.com/commander/-/commander-5.1.0.tgz#46abbd1652f8e059bddaef99bbdcb2ad9cf179ae" diff --git a/web/apps/cast/src/services/pair.ts b/web/apps/cast/src/services/pair.ts index b5646698cc..287122c456 100644 --- a/web/apps/cast/src/services/pair.ts +++ b/web/apps/cast/src/services/pair.ts @@ -82,7 +82,8 @@ export const register = async (): Promise => { // Register keypair with museum to get a pairing code. let pairingCode: string | undefined; - // TODO: eslint has fixed this spurious warning, but we're not on the latest + // [TODO: spurious while(true) eslint warning]. + // eslint has fixed this spurious warning, but we're not on the latest // version yet, so add a disable. // https://github.com/eslint/eslint/pull/18286 /* eslint-disable no-constant-condition */ diff --git a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx index 5fdcd23852..8980c8ed5f 100644 --- a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx +++ b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx @@ -82,7 +82,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) { } onClick={() => setOpenMLSettings(true)} - label={pt("ML search")} + label={pt("Face and magic search")} /> diff --git a/web/apps/photos/src/components/Sidebar/Preferences.tsx b/web/apps/photos/src/components/Sidebar/Preferences.tsx index 38cccf45de..000a1e44cc 100644 --- a/web/apps/photos/src/components/Sidebar/Preferences.tsx +++ b/web/apps/photos/src/components/Sidebar/Preferences.tsx @@ -85,14 +85,9 @@ export default function Preferences({ open, onClose, onRootClose }) { } onClick={() => setOpenMLSettings(true)} - label={pt("ML search")} + label={pt("Face and magic search")} /> - )} diff --git a/web/apps/photos/src/services/logout.ts b/web/apps/photos/src/services/logout.ts index 9722757689..ab4b1aa6c5 100644 --- a/web/apps/photos/src/services/logout.ts +++ b/web/apps/photos/src/services/logout.ts @@ -22,7 +22,7 @@ export const photosLogout = async () => { // See: [Note: Caching IDB instances in separate execution contexts]. 
try { - terminateMLWorker(); + await terminateMLWorker(); } catch (e) { ignoreError("face", e); } diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts index 8b90652e64..750a1fb186 100644 --- a/web/apps/photos/src/services/searchService.ts +++ b/web/apps/photos/src/services/searchService.ts @@ -1,13 +1,12 @@ import { isDesktop } from "@/base/app"; -import { ensureElectron } from "@/base/electron"; import log from "@/base/log"; import { FileType } from "@/media/file-type"; import { + clipMatches, isMLEnabled, isMLSupported, mlStatusSnapshot, } from "@/new/photos/services/ml"; -import { clipMatches } from "@/new/photos/services/ml/clip"; import type { Person } from "@/new/photos/services/ml/people"; import { EnteFile } from "@/new/photos/types/file"; import * as chrono from "chrono-node"; @@ -374,7 +373,7 @@ const searchClip = async ( searchPhrase: string, ): Promise => { if (!isMLEnabled()) return undefined; - const matches = await clipMatches(searchPhrase, ensureElectron()); + const matches = await clipMatches(searchPhrase); log.debug(() => ["clip/scores", matches]); return matches; }; diff --git a/web/packages/base/types/ipc.ts b/web/packages/base/types/ipc.ts index 7a11553835..c0644760c0 100644 --- a/web/packages/base/types/ipc.ts +++ b/web/packages/base/types/ipc.ts @@ -335,57 +335,21 @@ export interface Electron { // - ML /** - * Return a CLIP embedding of the given image. + * Create a new ML worker, terminating the older ones (if any). * - * See: [Note: Natural language search using CLIP] + * This creates a new Node.js utility process, and sets things up so that we + * can communicate directly with that utility process using a + * {@link MessagePort} that gets posted using "createMLWorker/port". * - * The input is a opaque float32 array representing the image. The layout - * and exact encoding of the input is specific to our implementation and the - * ML model (CLIP) we use. + * At the other end of that port will be an object that conforms to the + * {@link ElectronMLWorker} interface. * - * @returns A CLIP embedding (an array of 512 floating point values). + * For more details about the IPC flow, see: [Note: ML IPC]. + * + * Note: For simplicity of implementation, we assume that there is at most + * one outstanding call to {@link createMLWorker}. */ - computeCLIPImageEmbedding: (input: Float32Array) => Promise; - - /** - * Return a CLIP embedding of the given image if we already have the model - * downloaded and prepped. If the model is not available return `undefined`. - * - * This differs from the other sibling ML functions in that it doesn't wait - * for the model download to finish. It does trigger a model download, but - * then immediately returns `undefined`. At some future point, when the - * model downloaded finishes, calls to this function will start returning - * the result we seek. - * - * The reason for doing it in this asymmetric way is because CLIP text - * embeddings are used as part of deducing user initiated search results, - * and we don't want to block that interaction on a large network request. - * - * See: [Note: Natural language search using CLIP] - * - * @param text The string whose embedding we want to compute. - * - * @returns A CLIP embedding. - */ - computeCLIPTextEmbeddingIfAvailable: ( - text: string, - ) => Promise; - - /** - * Detect faces in the given image using YOLO. 
- * - * Both the input and output are opaque binary data whose internal structure - * is specific to our implementation and the model (YOLO) we use. - */ - detectFaces: (input: Float32Array) => Promise; - - /** - * Return a MobileFaceNet embeddings for the given faces. - * - * Both the input and output are opaque binary data whose internal structure - * is specific to our implementation and the model (MobileFaceNet) we use. - */ - computeFaceEmbeddings: (input: Float32Array) => Promise; + createMLWorker: () => void; // - Watch @@ -574,6 +538,65 @@ export interface Electron { clearPendingUploads: () => Promise; } +/** + * The shape of the object exposed by the Node.js ML worker process on the + * message port that the web layer obtains by doing {@link createMLWorker}. + */ +export interface ElectronMLWorker { + /** + * Return a CLIP embedding of the given image. + * + * See: [Note: Natural language search using CLIP] + * + * The input is a opaque float32 array representing the image. The layout + * and exact encoding of the input is specific to our implementation and the + * ML model (CLIP) we use. + * + * @returns A CLIP embedding (an array of 512 floating point values). + */ + computeCLIPImageEmbedding: (input: Float32Array) => Promise; + + /** + * Return a CLIP embedding of the given image if we already have the model + * downloaded and prepped. If the model is not available return `undefined`. + * + * This differs from the other sibling ML functions in that it doesn't wait + * for the model download to finish. It does trigger a model download, but + * then immediately returns `undefined`. At some future point, when the + * model downloaded finishes, calls to this function will start returning + * the result we seek. + * + * The reason for doing it in this asymmetric way is because CLIP text + * embeddings are used as part of deducing user initiated search results, + * and we don't want to block that interaction on a large network request. + * + * See: [Note: Natural language search using CLIP] + * + * @param text The string whose embedding we want to compute. + * + * @returns A CLIP embedding. + */ + computeCLIPTextEmbeddingIfAvailable: ( + text: string, + ) => Promise; + + /** + * Detect faces in the given image using YOLO. + * + * Both the input and output are opaque binary data whose internal structure + * is specific to our implementation and the model (YOLO) we use. + */ + detectFaces: (input: Float32Array) => Promise; + + /** + * Return a MobileFaceNet embeddings for the given faces. + * + * Both the input and output are opaque binary data whose internal structure + * is specific to our implementation and the model (MobileFaceNet) we use. + */ + computeFaceEmbeddings: (input: Float32Array) => Promise; +} + /** * Errors that have special semantics on the web side. * diff --git a/web/packages/base/worker/comlink-worker.ts b/web/packages/base/worker/comlink-worker.ts index 4562805b3b..330c5637bd 100644 --- a/web/packages/base/worker/comlink-worker.ts +++ b/web/packages/base/worker/comlink-worker.ts @@ -28,7 +28,7 @@ export class ComlinkWorker InstanceType> { /** The class (T) exposed by the web worker */ public remote: Promise>>; /** The web worker */ - private worker: Worker; + public worker: Worker; /** An arbitrary name associated with this ComlinkWorker for debugging. 
*/ private name: string; diff --git a/web/packages/new/photos/components/MLSettings.tsx b/web/packages/new/photos/components/MLSettings.tsx index 7e2d53f872..2e5a0c6def 100644 --- a/web/packages/new/photos/components/MLSettings.tsx +++ b/web/packages/new/photos/components/MLSettings.tsx @@ -124,7 +124,7 @@ export const MLSettings: React.FC = ({ {component} @@ -305,7 +305,7 @@ const ManageML: React.FC = ({ let status: string; switch (phase) { case "indexing": - status = pt("Indexing"); + status = pt("Running"); break; case "scheduled": status = pt("Scheduled"); @@ -319,9 +319,9 @@ const ManageML: React.FC = ({ const confirmDisableML = () => { setDialogBoxAttributesV2({ - title: pt("Disable ML search"), + title: pt("Disable face and magic search"), content: pt( - "Do you want to disable ML search on all your devices?", + "Do you want to disable face and magic search on all your devices?", ), close: { text: t("cancel") }, proceed: { @@ -356,7 +356,7 @@ const ManageML: React.FC = ({ justifyContent={"space-between"} > - {pt("Status")} + {pt("Indexing")} {status} diff --git a/web/packages/new/photos/components/MLSettingsBeta.tsx b/web/packages/new/photos/components/MLSettingsBeta.tsx index db1b83da11..2f9bae19f4 100644 --- a/web/packages/new/photos/components/MLSettingsBeta.tsx +++ b/web/packages/new/photos/components/MLSettingsBeta.tsx @@ -42,7 +42,7 @@ export const MLSettingsBeta: React.FC = ({ diff --git a/web/packages/new/photos/services/ml/blob.ts b/web/packages/new/photos/services/ml/blob.ts index 015dc7462d..d52772b6a4 100644 --- a/web/packages/new/photos/services/ml/blob.ts +++ b/web/packages/new/photos/services/ml/blob.ts @@ -1,4 +1,5 @@ import { basename } from "@/base/file"; +import type { ElectronMLWorker } from "@/base/types/ipc"; import { FileType } from "@/media/file-type"; import { decodeLivePhoto } from "@/media/live-photo"; import { ensure } from "@/utils/ensure"; @@ -7,7 +8,6 @@ import { renderableImageBlob } from "../../utils/file"; import { readStream } from "../../utils/native-stream"; import DownloadManager from "../download"; import type { UploadItem } from "../upload/types"; -import type { MLWorkerElectron } from "./worker-types"; /** * A pair of blobs - the original, and a possibly converted "renderable" one - @@ -103,13 +103,14 @@ export const imageBitmapAndData = async ( * be set to the {@link UploadItem} that was uploaded. This way, we can directly * use the on-disk file instead of needing to download the original from remote. * - * @param electron The {@link MLWorkerElectron} instance that allows us to call - * our Node.js layer for various functionality. + * @param electron The {@link ElectronMLWorker} instance that stands as a + * witness that we're actually running in our desktop app (and thus can safely + * call our Node.js layer for various functionality). */ export const indexableBlobs = async ( enteFile: EnteFile, uploadItem: UploadItem | undefined, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => uploadItem ? 
await indexableUploadItemBlobs(enteFile, uploadItem, electron) @@ -118,7 +119,7 @@ export const indexableBlobs = async ( const indexableUploadItemBlobs = async ( enteFile: EnteFile, uploadItem: UploadItem, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ) => { const fileType = enteFile.metadata.fileType; let originalImageBlob: Blob | undefined; @@ -149,7 +150,7 @@ const indexableUploadItemBlobs = async ( */ const readNonVideoUploadItem = async ( uploadItem: UploadItem, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => { if (typeof uploadItem == "string" || Array.isArray(uploadItem)) { const { response, lastModifiedMs } = await readStream( diff --git a/web/packages/new/photos/services/ml/clip.ts b/web/packages/new/photos/services/ml/clip.ts index eecf7e2209..78eff1c04d 100644 --- a/web/packages/new/photos/services/ml/clip.ts +++ b/web/packages/new/photos/services/ml/clip.ts @@ -1,9 +1,9 @@ -import type { Electron } from "@/base/types/ipc"; +import type { ElectronMLWorker } from "@/base/types/ipc"; import type { ImageBitmapAndData } from "./blob"; import { clipIndexes } from "./db"; import { pixelRGBBicubic } from "./image"; import { dotProduct, norm } from "./math"; -import type { MLWorkerElectron } from "./worker-types"; +import type { CLIPMatches } from "./worker-types"; /** * The version of the CLIP indexing pipeline implemented by the current client. @@ -98,19 +98,19 @@ export type LocalCLIPIndex = CLIPIndex & { * be set to the {@link UploadItem} that was uploaded. This way, we can directly * use the on-disk file instead of needing to download the original from remote. * - * @param electron The {@link MLWorkerElectron} instance that allows us to call + * @param electron The {@link ElectronMLWorker} instance that allows us to call * our Node.js layer to run the ONNX inference. */ export const indexCLIP = async ( image: ImageBitmapAndData, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => ({ embedding: await computeEmbedding(image.data, electron), }); const computeEmbedding = async ( imageData: ImageData, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => { const clipInput = convertToCLIPInput(imageData); return normalized(await electron.computeCLIPImageEmbedding(clipInput)); @@ -167,26 +167,15 @@ const normalized = (embedding: Float32Array) => { }; /** - * Use CLIP to perform a natural language search over image embeddings. - * - * @param searchPhrase The text entered by the user in the search box. - * - * @param electron The {@link Electron} instance to use to communicate with the - * native code running in our desktop app (the embedding happens in the native - * layer). - * - * It returns file (IDs) that should be shown in the search results. They're - * returned as a map from fileIDs to the scores they got (higher is better). - * This map will only contains entries whose score was above our minimum - * threshold. + * Find the files whose CLIP embedding "matches" the given {@link searchPhrase}. * * The result can also be `undefined`, which indicates that the download for the * ML model is still in progress (trying again later should succeed). 
*/ export const clipMatches = async ( searchPhrase: string, - electron: Electron, -): Promise | undefined> => { + electron: ElectronMLWorker, +): Promise => { const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase); if (!t) return undefined; diff --git a/web/packages/new/photos/services/ml/face.ts b/web/packages/new/photos/services/ml/face.ts index 910970d3b9..7ecbf06002 100644 --- a/web/packages/new/photos/services/ml/face.ts +++ b/web/packages/new/photos/services/ml/face.ts @@ -7,6 +7,7 @@ // /* eslint-disable @typescript-eslint/no-non-null-assertion */ +import type { ElectronMLWorker } from "@/base/types/ipc"; import type { EnteFile } from "@/new/photos/types/file"; import { Matrix } from "ml-matrix"; import { getSimilarityTransformation } from "similarity-transformation"; @@ -24,7 +25,6 @@ import { warpAffineFloat32List, } from "./image"; import { clamp } from "./math"; -import type { MLWorkerElectron } from "./worker-types"; /** * The version of the face indexing pipeline implemented by the current client. @@ -236,13 +236,13 @@ export interface Box { * * @param image The file's contents. * - * @param electron The {@link MLWorkerElectron} instance that allows us to call + * @param electron The {@link ElectronMLWorker} instance that allows us to call * our Node.js layer to run the ONNX inference. */ export const indexFaces = async ( enteFile: EnteFile, { data: imageData }: ImageBitmapAndData, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => ({ width: imageData.width, height: imageData.height, @@ -252,7 +252,7 @@ export const indexFaces = async ( const indexFaces_ = async ( fileID: number, imageData: ImageData, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => { const { width, height } = imageData; const imageDimensions = { width, height }; @@ -316,7 +316,7 @@ const indexFaces_ = async ( */ const detectFaces = async ( imageData: ImageData, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => { const rect = ({ width, height }: Dimensions) => ({ x: 0, @@ -878,7 +878,7 @@ const mobileFaceNetEmbeddingSize = 192; */ const computeEmbeddings = async ( faceData: Float32Array, - electron: MLWorkerElectron, + electron: ElectronMLWorker, ): Promise => { const outputData = await electron.computeFaceEmbeddings(faceData); diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts index c3432b1023..5b57dade21 100644 --- a/web/packages/new/photos/services/ml/index.ts +++ b/web/packages/new/photos/services/ml/index.ts @@ -6,17 +6,20 @@ import { isDesktop } from "@/base/app"; import { blobCache } from "@/base/blob-cache"; import { ensureElectron } from "@/base/electron"; import log from "@/base/log"; +import type { Electron } from "@/base/types/ipc"; import { ComlinkWorker } from "@/base/worker/comlink-worker"; import { FileType } from "@/media/file-type"; import type { EnteFile } from "@/new/photos/types/file"; +import { ensure } from "@/utils/ensure"; import { throttled } from "@/utils/promise"; -import { proxy } from "comlink"; +import { proxy, transfer } from "comlink"; import { isInternalUser } from "../feature-flags"; import { getRemoteFlag, updateRemoteFlag } from "../remote-store"; import type { UploadItem } from "../upload/types"; import { regenerateFaceCrops } from "./crop"; import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db"; import { MLWorker } from "./worker"; +import type { CLIPMatches } from "./worker-types"; /** * In-memory flag that 
tracks if ML is enabled. @@ -33,7 +36,7 @@ import { MLWorker } from "./worker"; let _isMLEnabled = false; /** Cached instance of the {@link ComlinkWorker} that wraps our web worker. */ -let _comlinkWorker: ComlinkWorker | undefined; +let _comlinkWorker: Promise> | undefined; /** * Subscriptions to {@link MLStatus}. @@ -50,29 +53,28 @@ let _mlStatusListeners: (() => void)[] = []; let _mlStatusSnapshot: MLStatus | undefined; /** Lazily created, cached, instance of {@link MLWorker}. */ -const worker = async () => { - if (!_comlinkWorker) _comlinkWorker = await createComlinkWorker(); - return _comlinkWorker.remote; -}; +const worker = () => + (_comlinkWorker ??= createComlinkWorker()).then((cw) => cw.remote); const createComlinkWorker = async () => { const electron = ensureElectron(); - const mlWorkerElectron = { - detectFaces: electron.detectFaces, - computeFaceEmbeddings: electron.computeFaceEmbeddings, - computeCLIPImageEmbedding: electron.computeCLIPImageEmbedding, - }; const delegate = { workerDidProcessFile, }; + // Obtain a message port from the Electron layer. + const messagePort = await createMLWorker(electron); + const cw = new ComlinkWorker( "ML", new Worker(new URL("worker.ts", import.meta.url)), ); + await cw.remote.then((w) => - w.init(proxy(mlWorkerElectron), proxy(delegate)), + // Forward the port to the web worker. + w.init(transfer(messagePort, [messagePort]), proxy(delegate)), ); + return cw; }; @@ -85,13 +87,40 @@ const createComlinkWorker = async () => { * * It is also called when the user pauses or disables ML. */ -export const terminateMLWorker = () => { +export const terminateMLWorker = async () => { if (_comlinkWorker) { - _comlinkWorker.terminate(); + await _comlinkWorker.then((cw) => cw.terminate()); _comlinkWorker = undefined; } }; +/** + * Obtain a port from the Node.js layer that can be used to communicate with the + * ML worker process. + */ +const createMLWorker = (electron: Electron): Promise => { + // The main process will do its thing, and send back the port it created to + // us by sending an message on the "createMLWorker/port" channel via the + // postMessage API. This roundabout way is needed because MessagePorts + // cannot be transferred via the usual send/invoke pattern. + + const port = new Promise((resolve) => { + const l = ({ source, data, ports }: MessageEvent) => { + // The source check verifies that the message is coming from our own + // preload script. The data is the message that was posted. + if (source == window && data == "createMLWorker/port") { + window.removeEventListener("message", l); + resolve(ensure(ports[0])); + } + }; + window.addEventListener("message", l); + }); + + electron.createMLWorker(); + + return port; +}; + /** * Return true if the current client supports ML. * @@ -163,7 +192,7 @@ export const disableML = async () => { await updateIsMLEnabledRemote(false); setIsMLEnabledLocal(false); _isMLEnabled = false; - terminateMLWorker(); + await terminateMLWorker(); triggerStatusUpdate(); }; @@ -369,6 +398,22 @@ const setInterimScheduledStatus = () => { const workerDidProcessFile = throttled(updateMLStatusSnapshot, 2000); +/** + * Use CLIP to perform a natural language search over image embeddings. + * + * @param searchPhrase The text entered by the user in the search box. + * + * It returns file (IDs) that should be shown in the search results, along with + * their scores. + * + * The result can also be `undefined`, which indicates that the download for the + * ML model is still in progress (trying again later should succeed). 
diff --git a/web/packages/new/photos/services/ml/worker-types.ts b/web/packages/new/photos/services/ml/worker-types.ts
index 1eb43933a3..72d6bce61b 100644
--- a/web/packages/new/photos/services/ml/worker-types.ts
+++ b/web/packages/new/photos/services/ml/worker-types.ts
@@ -1,22 +1,7 @@
 /**
- * @file Type for the objects shared (as a Comlink proxy) by the main thread and
- * the ML worker.
+ * @file Types for the objects shared between the main thread and the ML worker.
 */
 
-/**
- * A subset of {@link Electron} provided to the {@link MLWorker}.
- *
- * `globalThis.electron` does not exist in the execution context of web workers.
- * So instead, we manually provide a proxy object of type
- * {@link MLWorkerElectron} that exposes a subset of the functions from
- * {@link Electron} that are needed by the code running in the ML web worker.
- */
-export interface MLWorkerElectron {
-    detectFaces: (input: Float32Array) => Promise<Float32Array>;
-    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
-    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
-}
-
 /**
 * Callbacks invoked by the worker at various points in the indexing pipeline to
 * notify the main thread of events it might be interested in.
@@ -25,7 +10,18 @@ export interface MLWorkerDelegate {
     /**
     * Called whenever a file is processed during indexing.
     *
-     * It is called both when the indexing was successful or failed.
+     * It is called both when the indexing was successful and when it failed.
     */
     workerDidProcessFile: () => void;
 }
+
+/**
+ * The IDs of files that should be considered matches for a particular search
+ * phrase, each with its associated score.
+ *
+ * This is a map from the matching files' IDs to the scores they got (higher is
+ * better). It will only contain entries whose score was above our minimum
+ * threshold.
+ */
+export type CLIPMatches = Map<number, number>;
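For context, a caller on the main thread consumes this type through the `clipMatches` function exported from `services/ml` (added above). A hedged usage sketch follows; the import path, search phrase, and ranking step are illustrative and not part of this change.

```ts
import { clipMatches } from "@/new/photos/services/ml";

const searchWithCLIP = async (searchPhrase: string): Promise<number[]> => {
    const matches = await clipMatches(searchPhrase);
    // `undefined` means the CLIP text model hasn't finished downloading yet.
    if (!matches) return [];
    // Sort the matching file IDs so that the best scoring ones come first.
    return [...matches.entries()]
        .sort(([, score1], [, score2]) => score2 - score1)
        .map(([fileID]) => fileID);
};
```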
diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts
index 3baf512a8c..37d6dc259e 100644
--- a/web/packages/new/photos/services/ml/worker.ts
+++ b/web/packages/new/photos/services/ml/worker.ts
@@ -3,14 +3,15 @@ import { isHTTP4xxError } from "@/base/http";
 import { getKVN } from "@/base/kv";
 import { ensureAuthToken } from "@/base/local-user";
 import log from "@/base/log";
+import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { EnteFile } from "@/new/photos/types/file";
 import { fileLogID } from "@/new/photos/utils/file";
 import { ensure } from "@/utils/ensure";
 import { wait } from "@/utils/promise";
 import { DOMParser } from "@xmldom/xmldom";
-import { expose } from "comlink";
+import { expose, wrap } from "comlink";
 import downloadManager from "../download";
-import { cmpNewLib2, extractRawExif } from "../exif";
+import { cmpNewLib2, extractRawExif, type RawExifTags } from "../exif";
 import { getAllLocalFiles, getLocalTrashedFiles } from "../files";
 import type { UploadItem } from "../upload/types";
 import {
@@ -18,7 +19,12 @@ import {
     indexableBlobs,
     type ImageBitmapAndData,
 } from "./blob";
-import { clipIndexingVersion, indexCLIP, type CLIPIndex } from "./clip";
+import {
+    clipIndexingVersion,
+    clipMatches,
+    indexCLIP,
+    type CLIPIndex,
+} from "./clip";
 import { saveFaceCrops } from "./crop";
 import {
     indexableFileIDs,
@@ -29,10 +35,11 @@
 import {
     fetchDerivedData,
     putDerivedData,
+    type RawRemoteDerivedData,
     type RemoteDerivedData,
 } from "./embedding";
 import { faceIndexingVersion, indexFaces, type FaceIndex } from "./face";
-import type { MLWorkerDelegate, MLWorkerElectron } from "./worker-types";
+import type { CLIPMatches, MLWorkerDelegate } from "./worker-types";
 
 const idleDurationStart = 5; /* 5 seconds */
 const idleDurationMax = 16 * 60; /* 16 minutes */
@@ -67,9 +74,12 @@ interface IndexableItem {
 * - "backfillq": fetching remote embeddings of unindexed items, and then
 *   indexing them if needed,
 * - "idle": in between state transitions.
+ *
+ * In addition, MLWorker can also be invoked for interactive tasks: in
+ * particular, for finding the closest CLIP match when the user does a search.
 */
 export class MLWorker {
-    private electron: MLWorkerElectron | undefined;
+    private electron: ElectronMLWorker | undefined;
     private delegate: MLWorkerDelegate | undefined;
     private state: "idle" | "indexing" = "idle";
     private liveQ: IndexableItem[] = [];
@@ -82,15 +92,16 @@ export class MLWorker {
     * This is conceptually the constructor, however it is easier to have this
     * as a separate function to avoid complicating the comlink types further.
     *
-     * @param electron The {@link MLWorkerElectron} that allows the worker to
-     * use the functionality provided by our Node.js layer when running in the
-     * context of our desktop app.
+     * @param port A {@link MessagePort} that allows us to communicate with an
+     * Electron utility process running in the Node.js layer of our desktop app,
+     * exposing an object that conforms to the {@link ElectronMLWorker}
+     * interface.
     *
     * @param delegate The {@link MLWorkerDelegate} the worker can use to inform
     * the main thread of interesting events.
     */
-    async init(electron: MLWorkerElectron, delegate?: MLWorkerDelegate) {
-        this.electron = electron;
+    async init(port: MessagePort, delegate: MLWorkerDelegate) {
+        this.electron = wrap<ElectronMLWorker>(port);
         this.delegate = delegate;
         // Initialize the downloadManager running in the web worker with the
         // user's token. It'll be used to download files to index if needed.
@@ -176,6 +187,13 @@
         return this.state == "indexing";
     }
 
+    /**
+     * Find {@link CLIPMatches} for a given {@link searchPhrase}.
+     */
+    async clipMatches(searchPhrase: string): Promise<CLIPMatches | undefined> {
+        return clipMatches(searchPhrase, ensure(this.electron));
+    }
+
     private async tick() {
         log.debug(() => [
             "ml/tick",
@@ -224,7 +242,7 @@
     }
 
     /** Return the next batch of items to backfill (if any). */
-    async backfillQ() {
+    private async backfillQ() {
         const userID = ensure(await getKVN("userID"));
         // Find files that our local DB thinks need syncing.
         const filesByID = await syncWithLocalFilesAndGetFilesToIndex(
@@ -256,7 +274,7 @@ expose(MLWorker);
 */
 const indexNextBatch = async (
     items: IndexableItem[],
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
     delegate: MLWorkerDelegate | undefined,
 ) => {
     // Don't try to index if we wouldn't be able to upload them anyway. The
@@ -270,18 +288,37 @@ const indexNextBatch = async (
     // Nothing to do.
     if (items.length == 0) return false;
 
-    // Index, keeping track if any of the items failed.
+    // Keep track if any of the items failed.
     let allSuccess = true;
-    for (const item of items) {
-        try {
-            await index(item, electron);
-            delegate?.workerDidProcessFile();
-            // Possibly unnecessary, but let us drain the microtask queue.
-            await wait(0);
-        } catch (e) {
-            log.warn(`Skipping unindexable file ${item.enteFile.id}`, e);
-            allSuccess = false;
+
+    // Index up to 4 items simultaneously.
+    const tasks = new Array<Promise<void> | undefined>(4).fill(undefined);
+
+    let i = 0;
+    while (i < items.length) {
+        for (let j = 0; j < tasks.length; j++) {
+            if (i < items.length && !tasks[j]) {
+                tasks[j] = index(ensure(items[i++]), electron)
+                    .then(() => {
+                        tasks[j] = undefined;
+                    })
+                    .catch(() => {
+                        allSuccess = false;
+                        tasks[j] = undefined;
+                    });
+            }
         }
+
+        // Wait for at least one to complete (the other runners continue running
+        // even if one promise reaches the finish line).
+        await Promise.race(tasks);
+
+        // Let the main thread know we're doing something.
+        delegate?.workerDidProcessFile();
+
+        // Let us drain the microtask queue. This also gives a chance for other
+        // interactive tasks like `clipMatches` to run.
+        await wait(0);
     }
 
     // Return true if nothing failed.
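The loop added above is, in effect, a small fixed-size task pool. To make the pattern easier to see in isolation, here is an equivalent standalone helper; it does not exist in the codebase and is illustrative only, since the change presumably keeps the logic inline so it can also notify the delegate and yield to interactive requests such as `clipMatches` between iterations.

```ts
// Hypothetical helper illustrating the pattern used by indexNextBatch: run
// `work` over `items` with at most `limit` tasks in flight, and report whether
// every task succeeded.
const runBounded = async <T>(
    items: T[],
    limit: number,
    work: (item: T) => Promise<void>,
): Promise<boolean> => {
    let allSuccess = true;
    const inFlight = new Set<Promise<void>>();
    for (const item of items) {
        const task = work(item)
            .catch(() => {
                allSuccess = false;
            })
            .finally(() => inFlight.delete(task));
        inFlight.add(task);
        // Once the pool is full, wait for any one task to finish.
        if (inFlight.size >= limit) await Promise.race(inFlight);
    }
    // Wait for the stragglers.
    await Promise.all(inFlight);
    return allSuccess;
};
```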
@@ -374,7 +411,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
 */
 const index = async (
     { enteFile, uploadItem, remoteDerivedData }: IndexableItem,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ) => {
     const f = fileLogID(enteFile);
     const fileID = enteFile.id;
@@ -463,10 +500,7 @@ const index = async (
         [faceIndex, clipIndex, exif] = await Promise.all([
             existingFaceIndex ?? indexFaces(enteFile, image, electron),
             existingCLIPIndex ?? indexCLIP(image, electron),
-            existingExif ??
-                (originalImageBlob
-                    ? extractRawExif(originalImageBlob)
-                    : undefined),
+            existingExif ?? tryExtractExif(originalImageBlob, f),
         ]);
     } catch (e) {
         // See: [Note: Transient and permanent indexing failures]
@@ -475,8 +509,12 @@
         throw e;
     }
 
-    if (originalImageBlob && exif)
-        await cmpNewLib2(enteFile, originalImageBlob, exif);
+    try {
+        if (originalImageBlob && exif)
+            await cmpNewLib2(enteFile, originalImageBlob, exif);
+    } catch (e) {
+        log.warn(`Skipping exif cmp for ${f}`, e);
+    }
 
     log.debug(() => {
         const ms = Date.now() - startTime;
@@ -504,22 +542,27 @@
     // parts. See: [Note: Preserve unknown derived data fields].
     const existingRawDerivedData = remoteDerivedData?.raw ?? {};
 
-    const rawDerivedData = {
+    const rawDerivedData: RawRemoteDerivedData = {
         ...existingRawDerivedData,
         face: remoteFaceIndex,
         clip: remoteCLIPIndex,
-        exif,
+        ...(exif ? { exif } : {}),
     };
 
-    log.debug(() => ["Uploading derived data", rawDerivedData]);
+    if (existingFaceIndex && existingCLIPIndex && !exif) {
+        // If we were indexing just for exif, but exif generation didn't
+        // happen, there is no need to upload.
+    } else {
+        log.debug(() => ["Uploading derived data", rawDerivedData]);
 
-    try {
-        await putDerivedData(enteFile, rawDerivedData);
-    } catch (e) {
-        // See: [Note: Transient and permanent indexing failures]
-        log.error(`Failed to put derived data for ${f}`, e);
-        if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
-        throw e;
+        try {
+            await putDerivedData(enteFile, rawDerivedData);
+        } catch (e) {
+            // See: [Note: Transient and permanent indexing failures]
+            log.error(`Failed to put derived data for ${f}`, e);
+            if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
+            throw e;
+        }
     }
 
     try {
@@ -550,3 +593,34 @@
         image.bitmap.close();
     }
 };
+
+/**
+ * A helper function that tries to extract the raw Exif, but returns `undefined`
+ * if something goes wrong (or it isn't possible) instead of throwing.
+ *
+ * Exif extraction is not critical, and we don't want the actual indexing to
+ * fail because we were unable to extract Exif. This is not rare: one scenario
+ * is if we were trying to index a file in an exotic format. The ML indexing
+ * will succeed (because we convert it to a renderable blob), but the Exif
+ * extraction will fail (since it needs the original blob, and the original
+ * blob can be in an arbitrary format).
+ *
+ * @param originalImageBlob A {@link Blob} containing the original data for the
+ * image (or the image component of a live photo) whose Exif we're trying to
+ * extract. If this is not available, we skip the extraction and return
+ * `undefined`.
+ *
+ * @param f The {@link fileLogID} for the file this blob corresponds to.
+ */
+export const tryExtractExif = async (
+    originalImageBlob: Blob | undefined,
+    f: string,
+): Promise<RawExifTags | undefined> => {
+    if (!originalImageBlob) return undefined;
+    try {
+        return await extractRawExif(originalImageBlob);
+    } catch (e) {
+        log.warn(`Ignoring error during Exif extraction for ${f}`, e);
+        return undefined;
+    }
+};
diff --git a/web/packages/new/photos/utils/native-stream.ts b/web/packages/new/photos/utils/native-stream.ts
index 6f61656597..6aee016c0c 100644
--- a/web/packages/new/photos/utils/native-stream.ts
+++ b/web/packages/new/photos/utils/native-stream.ts
@@ -6,8 +6,7 @@
 * See: [Note: IPC streams].
 */
 
-import type { Electron, ZipItem } from "@/base/types/ipc";
-import type { MLWorkerElectron } from "../services/ml/worker-types";
+import type { Electron, ElectronMLWorker, ZipItem } from "@/base/types/ipc";
 
 /**
 * Stream the given file or zip entry from the user's local file system.
@@ -18,7 +17,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
 *
 * To avoid accidentally invoking it in a non-desktop app context, it requires
 * the {@link Electron} (or a functionally similar) object as a parameter (even
- * though it doesn't use it).
+ * though it doesn't need or use it).
 *
 * @param pathOrZipItem Either the path on the file on the user's local file
 * system whose contents we want to stream. Or a tuple containing the path to a
@@ -36,7 +35,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
 * reading, expressed as epoch milliseconds.
 */
 export const readStream = async (
-    _: Electron | MLWorkerElectron,
+    _: Electron | ElectronMLWorker,
     pathOrZipItem: string | ZipItem,
 ): Promise<{ response: Response; size: number; lastModifiedMs: number }> => {
     let url: URL;