[desktop] Use a utility process for ML indexing (#2583)

* This ensures that the UI remains responsive when indexing is running.
* This allows us to parallelize the in-flight networking etc. instead of
waterfalling everything.
* The actual CPU-intensive indexing is still serial.
Manav Rathi
2024-07-31 16:27:28 +05:30
committed by GitHub
28 changed files with 885 additions and 469 deletions
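In rough outline, the main-process side of this looks as follows (a condensed, illustrative sketch; the actual implementation is in the file diffs below):

import { app, MessageChannelMain, utilityProcess, type BrowserWindow } from "electron/main";
import path from "node:path";

// Spawn the ML utility process and wire up a message channel: one port goes to
// the utility process, the other is relayed to the renderer, which forwards it
// to its ML web worker. Thereafter the two ends talk to each other directly.
const createMLWorker = (window: BrowserWindow) => {
    const { port1, port2 } = new MessageChannelMain();
    const child = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
    // Ports can only be transferred via the postMessage APIs, not send/invoke.
    child.postMessage({ userDataPath: app.getPath("userData") }, [port1]);
    window.webContents.postMessage("createMLWorker/port", undefined, [port2]);
};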

View File

@@ -28,6 +28,7 @@
"auto-launch": "^5.0",
"chokidar": "^3.6",
"clip-bpe-js": "^0.0.6",
"comlink": "^4.4.1",
"compare-versions": "^6.1",
"electron-log": "^5.1",
"electron-store": "^8.2",

View File

@@ -21,6 +21,7 @@ import {
attachFSWatchIPCHandlers,
attachIPCHandlers,
attachLogoutIPCHandler,
attachMainWindowIPCHandlers,
} from "./main/ipc";
import log, { initLogging } from "./main/log";
import { createApplicationMenu, createTrayContextMenu } from "./main/menu";
@@ -121,6 +122,7 @@ const main = () => {
// Setup IPC and streams.
const watcher = createWatcher(mainWindow);
attachIPCHandlers();
attachMainWindowIPCHandlers(mainWindow);
attachFSWatchIPCHandlers(watcher);
attachLogoutIPCHandler(watcher);
registerStreamProtocol();

View File

@@ -9,6 +9,7 @@
*/
import type { FSWatcher } from "chokidar";
import type { BrowserWindow } from "electron";
import { ipcMain } from "electron/main";
import type {
CollectionMapping,
@@ -42,11 +43,7 @@ import {
} from "./services/fs";
import { convertToJPEG, generateImageThumbnail } from "./services/image";
import { logout } from "./services/logout";
import {
computeCLIPImageEmbedding,
computeCLIPTextEmbeddingIfAvailable,
} from "./services/ml-clip";
import { computeFaceEmbeddings, detectFaces } from "./services/ml-face";
import { createMLWorker } from "./services/ml";
import {
encryptionKey,
lastShownChangelogVersion,
@@ -184,24 +181,6 @@ export const attachIPCHandlers = () => {
) => ffmpegExec(command, dataOrPathOrZipItem, outputFileExtension),
);
// - ML
ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) =>
computeCLIPImageEmbedding(input),
);
ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>
computeCLIPTextEmbeddingIfAvailable(text),
);
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);
ipcMain.handle("computeFaceEmbeddings", (_, input: Float32Array) =>
computeFaceEmbeddings(input),
);
// - Upload
ipcMain.handle("listZipItems", (_, zipPath: string) =>
@@ -231,6 +210,16 @@ export const attachIPCHandlers = () => {
ipcMain.handle("clearPendingUploads", () => clearPendingUploads());
};
/**
* A subset of {@link attachIPCHandlers} for functions that need a reference to
* the main window to do their thing.
*/
export const attachMainWindowIPCHandlers = (mainWindow: BrowserWindow) => {
// - ML
ipcMain.on("createMLWorker", () => createMLWorker(mainWindow));
};
/**
* Sibling of {@link attachIPCHandlers} that attaches handlers specific to the
* watch folder functionality.

View File

@@ -1,68 +0,0 @@
/**
* @file Compute CLIP embeddings for images and text.
*
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*/
import Tokenizer from "clip-bpe-js";
import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure, wait } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";
const cachedCLIPImageSession = makeCachedInferenceSession(
"clip-image-vit-32-float32.onnx",
351468764 /* 335.2 MB */,
);
export const computeCLIPImageEmbedding = async (input: Float32Array) => {
const session = await cachedCLIPImageSession();
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
};
const results = await session.run(feeds);
log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to type the result */
return ensure(results.output).data as Float32Array;
};
const cachedCLIPTextSession = makeCachedInferenceSession(
"clip-text-vit-32-uint8.onnx",
64173509 /* 61.2 MB */,
);
let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
if (!_tokenizer) _tokenizer = new Tokenizer();
return _tokenizer;
};
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrSkip = await Promise.race([
cachedCLIPTextSession(),
// Wait for a tick to get the session promise to resolved the first time
// this code runs on each app start (and the model has been downloaded).
wait(0).then(() => 1),
]);
// Don't wait for the download to complete.
if (typeof sessionOrSkip == "number") {
log.info(
"Ignoring CLIP text embedding request because model download is pending",
);
return undefined;
}
const session = sessionOrSkip;
const t = Date.now();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
const feeds = {
input: new ort.Tensor("int32", tokenizedText, [1, 77]),
};
const results = await session.run(feeds);
log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
return ensure(results.output).data as Float32Array;
};

View File

@@ -1,53 +0,0 @@
/**
* @file Various face recognition related tasks.
*
* - Face detection with the YOLO model.
* - Face embedding with the MobileFaceNet model.
*
* The runtime used is ONNX.
*/
import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";
const cachedFaceDetectionSession = makeCachedInferenceSession(
"yolov5s_face_640_640_dynamic.onnx",
30762872 /* 29.3 MB */,
);
export const detectFaces = async (input: Float32Array) => {
const session = await cachedFaceDetectionSession();
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
return ensure(results.output).data;
};
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
"mobilefacenet_opset15.onnx",
5286998 /* 5 MB */,
);
export const computeFaceEmbeddings = async (input: Float32Array) => {
// Dimension of each face (alias)
const mobileFaceNetFaceSize = 112;
// Smaller alias
const z = mobileFaceNetFaceSize;
// Size of each face's data in the batch
const n = Math.round(input.length / (z * z * 3));
const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
const session = await cachedFaceEmbeddingSession();
const t = Date.now();
const feeds = { img_inputs: inputTensor };
const results = await session.run(feeds);
log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to extract and type the result */
return (results.embeddings as unknown as Record<string, unknown>)
.cpuData as Float32Array;
};

View File

@@ -0,0 +1,315 @@
/**
* @file ML related tasks. This code runs in a utility process.
*
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
* for various tasks are not shipped with the app but are downloaded on demand.
*/
// See [Note: Using Electron APIs in UtilityProcess] about what we can and
// cannot import.
import Tokenizer from "clip-bpe-js";
import { expose } from "comlink";
import { net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import { messagePortMainEndpoint } from "../utils/comlink";
import { ensure, wait } from "../utils/common";
import { writeStream } from "../utils/stream";
/**
* We cannot do
*
* import log from "../log";
*
* because that requires the Electron APIs that are not available to a utility
* process (See: [Note: Using Electron APIs in UtilityProcess]). But even if
* that were to work, logging will still be problematic since we'd try opening
* the log file from two different Node.js processes (this one, and the main
* one), and I didn't find any indication in the electron-log repository that
* the log file's integrity would be maintained in such cases.
*
* So instead we create this proxy log object that uses `process.parentPort` to
* transport the logs over to the main process.
*/
const log = {
/**
* Unlike the real {@link log.error}, this accepts only the first string
* argument, not the second optional error one.
*/
errorString: (s: string) => mainProcess("log.errorString", s),
info: (...ms: unknown[]) => mainProcess("log.info", ms),
/**
* Unlike the real {@link log.debug}, this is (a) eagerly evaluated, and (b)
* accepts only strings.
*/
debugString: (s: string) => mainProcess("log.debugString", s),
};
/**
* Send a message to the main process using a barebones RPC protocol.
*/
const mainProcess = (method: string, param: unknown) =>
process.parentPort.postMessage({ method, p: param });
log.debugString(`Started ML worker process`);
process.parentPort.once("message", (e) => {
// Initialize ourselves with the data we got from our parent.
parseInitData(e.data);
// Expose an instance of `ElectronMLWorker` on the port we got from our
// parent.
expose(
{
computeCLIPImageEmbedding,
computeCLIPTextEmbeddingIfAvailable,
detectFaces,
computeFaceEmbeddings,
},
messagePortMainEndpoint(ensure(e.ports[0])),
);
});
/**
* We cannot access Electron's {@link app} object within a utility process, so
* we pass the value of `app.getPath("userData")` during initialization, and it
* can be subsequently retrieved from here.
*/
let _userDataPath: string | undefined;
/** Equivalent to app.getPath("userData") */
const userDataPath = () => ensure(_userDataPath);
const parseInitData = (data: unknown) => {
if (
data &&
typeof data == "object" &&
"userDataPath" in data &&
typeof data.userDataPath == "string"
) {
_userDataPath = data.userDataPath;
} else {
log.errorString("Unparseable initialization data");
}
};
/**
* Return a function that can be used to trigger a download of the specified
* model, and the creation of an ONNX inference session initialized using it.
*
* Multiple parallel calls to the returned function are fine: it ensures that
* the model will be downloaded and the session created using it only once.
* All pending calls made in the meanwhile will just await the same promise.
*
* And once the promise is resolved, the created ONNX inference session will be
* cached, so subsequent calls to the returned function will just reuse the same
* session.
*
* {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
* actively trigger a download until the returned function is called.
*
* @param modelName The name of the model to download.
*
* @param modelByteSize The size in bytes that we expect the model to have. If
* the size of the downloaded model does not match the expected size, then we
* will redownload it.
*
* @returns A function. Calling that function returns a promise that resolves
* to an ONNX inference session.
*/
const makeCachedInferenceSession = (
modelName: string,
modelByteSize: number,
) => {
let session: Promise<ort.InferenceSession> | undefined;
const download = () =>
modelPathDownloadingIfNeeded(modelName, modelByteSize);
const createSession = (modelPath: string) =>
createInferenceSession(modelPath);
const cachedInferenceSession = () => {
if (!session) session = download().then(createSession);
return session;
};
return cachedInferenceSession;
};
/**
* Download the model named {@link modelName} if we don't already have it.
*
* Also verify that the size of the model we get matches {@link expectedByteSize} (if
* not, redownload it).
*
* @returns the path to the model on the local machine.
*/
const modelPathDownloadingIfNeeded = async (
modelName: string,
expectedByteSize: number,
) => {
const modelPath = modelSavePath(modelName);
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
await downloadModel(modelPath, modelName);
} else {
const size = (await fs.stat(modelPath)).size;
if (size !== expectedByteSize) {
log.errorString(
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
);
await downloadModel(modelPath, modelName);
}
}
return modelPath;
};
/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(userDataPath(), "models", modelName);
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download.
log.info(`Downloading ML model ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
const body = res.body;
if (!body) throw new Error(`Received a null response for ${url}`);
// Save.
await writeStream(saveLocation, body);
log.info(`Downloaded ML model ${name}`);
};
/**
* Create an ONNX {@link InferenceSession} with some defaults.
*/
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1.
intraOpNumThreads: 1,
// Be more conservative with RAM usage.
enableCpuMemArena: false,
});
};
const cachedCLIPImageSession = makeCachedInferenceSession(
"clip-image-vit-32-float32.onnx",
351468764 /* 335.2 MB */,
);
/**
* Compute CLIP embeddings for an image.
*
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*/
export const computeCLIPImageEmbedding = async (input: Float32Array) => {
const session = await cachedCLIPImageSession();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
};
const t = Date.now();
const results = await session.run(feeds);
log.debugString(`ONNX/CLIP image embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to type the result */
return ensure(results.output).data as Float32Array;
};
const cachedCLIPTextSession = makeCachedInferenceSession(
"clip-text-vit-32-uint8.onnx",
64173509 /* 61.2 MB */,
);
let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
if (!_tokenizer) _tokenizer = new Tokenizer();
return _tokenizer;
};
/**
* Compute CLIP embeddings for a text snippet.
*
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*/
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrSkip = await Promise.race([
cachedCLIPTextSession(),
// Wait for a tick to give the session promise a chance to resolve the first
// time this code runs on each app start (and the model has already been downloaded).
wait(0).then(() => 1),
]);
// Don't wait for the download to complete.
if (typeof sessionOrSkip == "number") {
log.info(
"Ignoring CLIP text embedding request because model download is pending",
);
return undefined;
}
const session = sessionOrSkip;
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
const feeds = {
input: new ort.Tensor("int32", tokenizedText, [1, 77]),
};
const t = Date.now();
const results = await session.run(feeds);
log.debugString(`ONNX/CLIP text embedding took ${Date.now() - t} ms`);
return ensure(results.output).data as Float32Array;
};
const cachedFaceDetectionSession = makeCachedInferenceSession(
"yolov5s_face_640_640_dynamic.onnx",
30762872 /* 29.3 MB */,
);
/**
* Face detection with the YOLO model and ONNX runtime.
*/
export const detectFaces = async (input: Float32Array) => {
const session = await cachedFaceDetectionSession();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const t = Date.now();
const results = await session.run(feeds);
log.debugString(`ONNX/YOLO face detection took ${Date.now() - t} ms`);
return ensure(results.output).data;
};
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
"mobilefacenet_opset15.onnx",
5286998 /* 5 MB */,
);
/**
* Face embedding with the MobileFaceNet model and ONNX runtime.
*/
export const computeFaceEmbeddings = async (input: Float32Array) => {
// Dimension of each face (alias)
const mobileFaceNetFaceSize = 112;
// Smaller alias
const z = mobileFaceNetFaceSize;
// Number of faces in the batch
const n = Math.round(input.length / (z * z * 3));
const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
const session = await cachedFaceEmbeddingSession();
const feeds = { img_inputs: inputTensor };
const t = Date.now();
const results = await session.run(feeds);
log.debugString(`ONNX/MFNT face embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to extract and type the result */
return (results.embeddings as unknown as Record<string, unknown>)
.cpuData as Float32Array;
};

View File

@@ -1,126 +1,147 @@
/**
* @file ML related functionality, generic layer.
*
* @see also `ml-clip.ts`, `ml-face.ts`.
*
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
* for various tasks are not shipped with the app but are downloaded on demand.
*
* The primary reason for doing these tasks in the Node.js layer is so that we
* can use the binary ONNX runtime which is 10-20x faster than the WASM based
* web one.
* @file ML related functionality. This code runs in the main process.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import {
MessageChannelMain,
type BrowserWindow,
type UtilityProcess,
} from "electron";
import { app, utilityProcess } from "electron/main";
import path from "node:path";
import * as ort from "onnxruntime-node";
import log from "../log";
import { writeStream } from "../stream";
/** The active ML worker (utility) process, if any. */
let _child: UtilityProcess | undefined;
/**
* Return a function that can be used to trigger a download of the specified
* model, and the creating of an ONNX inference session initialized using it.
* Create a new ML worker process, terminating the older ones (if any).
*
* Multiple parallel calls to the returned function are fine, it ensures that
* the the model will be downloaded and the session created using it only once.
* All pending calls to it meanwhile will just await on the same promise.
* [Note: ML IPC]
*
* And once the promise is resolved, the create ONNX inference session will be
* cached, so subsequent calls to the returned function will just reuse the same
* session.
* The primary reason for doing ML tasks in the Node.js layer is so that we can
* use the binary ONNX runtime, which is 10-20x faster than the WASM one that
* can be used directly on the web layer.
*
* {@link makeCachedInferenceSession} can itself be called anytime, it doesn't
* actively trigger a download until the returned function is called.
* For this to work, the main and renderer processes need to communicate with each
* other. Further, in the web layer the ML indexing runs in a web worker (so as
* to not get in the way of the main thread). So the communication has 2 hops:
*
* @param modelName The name of the model to download.
* Node.js main <-> Renderer main <-> Renderer web worker
*
* @param modelByteSize The size in bytes that we expect the model to have. If
* the size of the downloaded model does not match the expected size, then we
* will redownload it.
* This naive way works, but has a problem. The Node.js main process is in the
* code path for delivering user events to the renderer process. The ML tasks we
* do take on the order of 100-300 ms (possibly more) for each individual
* inference. Thus, the Node.js main process is busy for those 100-300 ms, and
* does not forward events to the renderer, causing the UI to jitter.
*
* @returns a function. calling that function returns a promise to an ONNX
* session.
* The solution for this is to spawn an Electron UtilityProcess, which we can
* think of as a regular Node.js child process. This frees up the Node.js main
* process, removing the jitter.
* https://www.electronjs.org/docs/latest/tutorial/process-model
*
* It would seem that this introduces another hop in our IPC
*
* Node.js utility process <-> Node.js main <-> ...
*
* but here we can use the special bit about Electron utility processes that
* separates them from regular Node.js child processes: their support for
* message ports. https://www.electronjs.org/docs/latest/tutorial/message-ports
*
* As a brief summary, a MessagePort is a web feature that allows two contexts
* to communicate. A pair of message ports is called a message channel. The cool
* thing about these is that we can pass these ports themselves over IPC.
*
* > One caveat here is that the message ports can only be passed using the
* > `postMessage` APIs, not the usual send/invoke APIs.
*
* So we
*
* 1. In the main process, create a message channel.
* 2. Spawn a utility process, and send one port of the pair to it.
* 3. Send the other port of the pair to the renderer.
*
* The renderer will forward that port to the web worker that is coordinating
* the ML indexing on the web layer. Thereafter, the utility process and web
* worker can directly talk to each other!
*
* Node.js utility process <-> Renderer web worker
*
* The RPC protocol is handled using comlink on both ends. The port itself needs
* to be relayed using `postMessage`.
*/
export const makeCachedInferenceSession = (
modelName: string,
modelByteSize: number,
) => {
let session: Promise<ort.InferenceSession> | undefined;
const download = () =>
modelPathDownloadingIfNeeded(modelName, modelByteSize);
const createSession = (modelPath: string) =>
createInferenceSession(modelPath);
const cachedInferenceSession = () => {
if (!session) session = download().then(createSession);
return session;
};
return cachedInferenceSession;
};
/**
* Download the model named {@link modelName} if we don't already have it.
*
* Also verify that the size of the model we get matches {@expectedByteSize} (if
* not, redownload it).
*
* @returns the path to the model on the local machine.
*/
const modelPathDownloadingIfNeeded = async (
modelName: string,
expectedByteSize: number,
) => {
const modelPath = modelSavePath(modelName);
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
await downloadModel(modelPath, modelName);
} else {
const size = (await fs.stat(modelPath)).size;
if (size !== expectedByteSize) {
log.error(
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
);
await downloadModel(modelPath, modelName);
}
export const createMLWorker = (window: BrowserWindow) => {
if (_child) {
log.debug(() => "Terminating previous ML worker process");
_child.kill();
_child = undefined;
}
return modelPath;
};
const { port1, port2 } = new MessageChannelMain();
/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
const child = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
const userDataPath = app.getPath("userData");
child.postMessage({ userDataPath }, [port1]);
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download.
log.info(`Downloading ML model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
const body = res.body;
if (!body) throw new Error(`Received an null response for ${url}`);
// Save.
await writeStream(saveLocation, body);
log.info(`Downloaded CLIP model ${name}`);
window.webContents.postMessage("createMLWorker/port", undefined, [port2]);
handleMessagesFromUtilityProcess(child);
_child = child;
};
/**
* Crete an ONNX {@link InferenceSession} with some defaults.
* Handle messages posted from the utility process.
*
* [Note: Using Electron APIs in UtilityProcess]
*
* Only a small subset of the Electron APIs are available to a UtilityProcess.
* As of writing (Jul 2024, Electron 30), only the following are available:
*
* - net
* - systemPreferences
*
* In particular, `app` is not available.
*
* We structure our code so that it doesn't need anything apart from `net`.
*
* For the other cases,
*
* - Additional parameters to the utility process are passed along with the
* initial message in which we provide it the message port.
*
* - When we need to communicate from the utility process to the main process,
* we use the `parentPort` in the utility process.
*/
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1.
intraOpNumThreads: 1,
// Be more conservative with RAM usage.
enableCpuMemArena: false,
const handleMessagesFromUtilityProcess = (child: UtilityProcess) => {
const logTag = "[ml-worker]";
child.on("message", (m: unknown) => {
if (m && typeof m == "object" && "method" in m && "p" in m) {
const p = m.p;
switch (m.method) {
case "log.errorString":
if (typeof p == "string") {
log.error(`${logTag} ${p}`);
return;
}
break;
case "log.info":
if (Array.isArray(p)) {
// Need to cast from any[] to unknown[]
log.info(logTag, ...(p as unknown[]));
return;
}
break;
case "log.debugString":
if (typeof p == "string") {
log.debug(() => `${logTag} ${p}`);
return;
}
break;
default:
break;
}
}
log.info("Ignoring unknown message from ML worker", m);
});
};
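At the far end of the relayed port, the web worker wraps it with comlink and talks to the utility process directly. A minimal sketch of that end (the helper name is illustrative; the real code is in the web ML worker diff further below):

import { wrap } from "comlink";
import type { ElectronMLWorker } from "@/base/types/ipc";

// `port` is the MessagePort that the preload script relayed to the web layer.
// Every call on the wrapped object is an RPC that goes straight to the utility
// process, bypassing both the Node.js main process and the renderer's main
// thread.
const attachMLWorkerPort = (port: MessagePort) => {
    const electron = wrap<ElectronMLWorker>(port);
    return (input: Float32Array) => electron.computeCLIPImageEmbedding(input);
};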

View File

@@ -3,7 +3,6 @@
*/
import { net, protocol } from "electron/main";
import { randomUUID } from "node:crypto";
import { createWriteStream, existsSync } from "node:fs";
import fs from "node:fs/promises";
import { Readable } from "node:stream";
import { ReadableStream } from "node:stream/web";
@@ -12,6 +11,7 @@ import log from "./log";
import { ffmpegConvertToMP4 } from "./services/ffmpeg";
import { markClosableZip, openZip } from "./services/zip";
import { ensure } from "./utils/common";
import { writeStream } from "./utils/stream";
import {
deleteTempFile,
deleteTempFileIgnoringErrors,
@@ -142,6 +142,7 @@ const handleReadZip = async (zipPath: string, entryName: string) => {
// https://github.com/antelle/node-stream-zip/blob/master/node_stream_zip.js
const modifiedMs = entry.time;
// @ts-expect-error [Note: Node and web stream type mismatch]
return new Response(webReadableStream, {
headers: {
// We don't know the exact type, but it doesn't really matter, just
@@ -159,39 +160,6 @@ const handleWrite = async (path: string, request: Request) => {
return new Response("", { status: 200 });
};
/**
* Write a (web) ReadableStream to a file at the given {@link filePath}.
*
* The returned promise resolves when the write completes.
*
* @param filePath The local file system path where the file should be written.
*
* @param readableStream A web
* [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream).
*/
export const writeStream = (filePath: string, readableStream: ReadableStream) =>
writeNodeStream(filePath, Readable.fromWeb(readableStream));
const writeNodeStream = async (filePath: string, fileStream: Readable) => {
const writeable = createWriteStream(filePath);
fileStream.on("error", (err) => {
writeable.destroy(err); // Close the writable stream with an error
});
fileStream.pipe(writeable);
await new Promise((resolve, reject) => {
writeable.on("finish", resolve);
writeable.on("error", (err) => {
if (existsSync(filePath)) {
void fs.unlink(filePath);
}
reject(err);
});
});
};
/**
* A map from token to file paths for convert-to-mp4 requests that we have
* received.

View File

@@ -0,0 +1,42 @@
import type { Endpoint } from "comlink";
import type { MessagePortMain } from "electron";
/**
* An adaptation of the `nodeEndpoint` function from comlink suitable for use in
* TypeScript with an Electron utility process.
*
* This is an adaptation of the following function from comlink:
* https://github.com/GoogleChromeLabs/comlink/blob/main/src/node-adapter.ts
*
* It has been modified (somewhat hackily) to be useful with an Electron
* MessagePortMain instead of a Node.js worker_thread. Only things that we
* currently need have been made to work as you can see by the abundant type
* casts. Caveat emptor.
*/
export const messagePortMainEndpoint = (mp: MessagePortMain): Endpoint => {
type NL = EventListenerOrEventListenerObject;
type EL = (data: Electron.MessageEvent) => void;
const listeners = new WeakMap<NL, EL>();
return {
postMessage: (message, transfer) => {
mp.postMessage(message, transfer as unknown as MessagePortMain[]);
},
addEventListener: (_, eh) => {
const l: EL = (data) =>
"handleEvent" in eh
? eh.handleEvent({ data } as MessageEvent)
: eh(data as unknown as MessageEvent);
mp.on("message", (data) => {
l(data);
});
listeners.set(eh, l);
},
removeEventListener: (_, eh) => {
const l = listeners.get(eh);
if (!l) return;
mp.off("message", l);
listeners.delete(eh);
},
start: mp.start.bind(mp),
};
};
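As a usage illustration (hypothetical; the real call site is the `expose` call in ml-worker.ts above, and `messagePortMainEndpoint` is the function defined in this file), a utility process could expose an object over the port it receives from its parent like so:

import { expose } from "comlink";

// Expose an object on the MessagePortMain handed to us by the parent process;
// comlink then services RPC calls that arrive on that port.
process.parentPort.once("message", (e) => {
    const port = e.ports[0];
    if (port) expose({ ping: () => "pong" }, messagePortMainEndpoint(port));
});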

View File

@@ -0,0 +1,39 @@
import { createWriteStream, existsSync } from "node:fs";
import fs from "node:fs/promises";
import { Readable } from "node:stream";
/**
* Write a (web) ReadableStream to a file at the given {@link filePath}.
*
* The returned promise resolves when the write completes.
*
* @param filePath The local file system path where the file should be written.
*
* @param readableStream A web
* [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream).
*
*/
export const writeStream = (
filePath: string,
readableStream: unknown /*ReadableStream*/,
) =>
// @ts-expect-error [Note: Node and web stream type mismatch]
writeNodeStream(filePath, Readable.fromWeb(readableStream));
const writeNodeStream = async (filePath: string, fileStream: Readable) => {
const writeable = createWriteStream(filePath);
fileStream.on("error", (err) => {
writeable.destroy(err); // Close the writable stream with an error
});
fileStream.pipe(writeable);
await new Promise((resolve, reject) => {
writeable.on("finish", resolve);
writeable.on("error", (err) => {
if (existsSync(filePath)) {
void fs.unlink(filePath);
}
reject(err);
});
});
};
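A hypothetical usage, mirroring how the ML worker saves a downloaded model (see ml-worker.ts above); `net` is Electron's net module, which is also available in utility processes:

import { net } from "electron/main";

// Stream a remote file to disk. writeStream resolves once the write finishes,
// and rejects (cleaning up the partial file) if the underlying stream errors.
const download = async (url: string, savePath: string) => {
    const res = await net.fetch(url);
    if (!res.ok || !res.body) throw new Error(`Failed to fetch ${url}`);
    await writeStream(savePath, res.body);
};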

View File

@@ -36,6 +36,28 @@
* - [main] desktop/src/main/ipc.ts contains impl
*/
// This code runs in the (isolated) web layer. Contrary to the impression given
// by the Electron docs (as of 2024), the window object is actually available to
// the preload script, and it is necessary for legitimate uses too.
//
// > The isolated world is connected to the DOM just the same as the main world,
// > it is just the JS contexts that are separated.
// >
// > https://github.com/electron/electron/issues/27024#issuecomment-745618327
//
// Adding this reference here tells TypeScript that DOM typings (in particular,
// window) should be introduced in the ambient scope.
//
// [Note: Node and web stream type mismatch]
//
// Unfortunately, adding this reference causes the ReadableStream typings to
// break since lib.dom.d.ts adds its own incompatible definitions of
// ReadableStream to the global scope.
//
// https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/68407
/// <reference lib="dom" />
import { contextBridge, ipcRenderer, webUtils } from "electron/renderer";
// While we can't import other code, we can import types since they're just
@@ -48,6 +70,19 @@ import type {
ZipItem,
} from "./types/ipc";
// - Infrastructure
// We need to wait until the renderer is ready before sending ports via
// postMessage, and this promise comes in handy in such cases. We create the
// promise at the top level so that it is guaranteed to be registered before the
// load event is fired.
//
// See: https://www.electronjs.org/docs/latest/tutorial/message-ports
const windowLoaded = new Promise((resolve) => {
window.onload = resolve;
});
// - General
const appVersion = () => ipcRenderer.invoke("appVersion");
@@ -163,17 +198,15 @@ const ffmpegExec = (
// - ML
const computeCLIPImageEmbedding = (input: Float32Array) =>
ipcRenderer.invoke("computeCLIPImageEmbedding", input);
const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);
const detectFaces = (input: Float32Array) =>
ipcRenderer.invoke("detectFaces", input);
const computeFaceEmbeddings = (input: Float32Array) =>
ipcRenderer.invoke("computeFaceEmbeddings", input);
const createMLWorker = () => {
ipcRenderer.send("createMLWorker");
ipcRenderer.on("createMLWorker/port", (event) => {
void windowLoaded.then(() => {
// "*"" is the origin to send to.
window.postMessage("createMLWorker/port", "*", event.ports);
});
});
};
// - Watch
@@ -281,8 +314,11 @@ const clearPendingUploads = () => ipcRenderer.invoke("clearPendingUploads");
* operation when it happens across threads.
* https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
*
* In our case though, we're not dealing with threads but separate processes. So
* the ArrayBuffer will be copied:
* In our case though, we're not dealing with threads but separate processes.
* Electron currently only supports transferring MessagePorts:
* https://github.com/electron/electron/issues/34905
*
* So the ArrayBuffer will be copied:
*
* > "parameters, errors and return values are **copied** when they're sent over
* > the bridge".
@@ -339,10 +375,7 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
computeCLIPImageEmbedding,
computeCLIPTextEmbeddingIfAvailable,
detectFaces,
computeFaceEmbeddings,
createMLWorker,
// - Watch

View File

@@ -968,6 +968,11 @@ combined-stream@^1.0.8:
dependencies:
delayed-stream "~1.0.0"
comlink@^4.4.1:
version "4.4.1"
resolved "https://registry.yarnpkg.com/comlink/-/comlink-4.4.1.tgz#e568b8e86410b809e8600eb2cf40c189371ef981"
integrity sha512-+1dlx0aY5Jo1vHy/tSsIGpSkN4tS9rZSW8FIhG0JH/crs9wwweswIo/POr451r7bZww3hFbPAKnTpimzL/mm4Q==
commander@^5.0.0:
version "5.1.0"
resolved "https://registry.yarnpkg.com/commander/-/commander-5.1.0.tgz#46abbd1652f8e059bddaef99bbdcb2ad9cf179ae"

View File

@@ -82,7 +82,8 @@ export const register = async (): Promise<Registration> => {
// Register keypair with museum to get a pairing code.
let pairingCode: string | undefined;
// TODO: eslint has fixed this spurious warning, but we're not on the latest
// [TODO: spurious while(true) eslint warning].
// eslint has fixed this spurious warning, but we're not on the latest
// version yet, so add a disable.
// https://github.com/eslint/eslint/pull/18286
/* eslint-disable no-constant-condition */

View File

@@ -82,7 +82,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
<EnteMenuItem
endIcon={<ChevronRight />}
onClick={() => setOpenMLSettings(true)}
label={pt("ML search")}
label={pt("Face and magic search")}
/>
</MenuItemGroup>
</Box>

View File

@@ -85,14 +85,9 @@ export default function Preferences({ open, onClose, onRootClose }) {
<EnteMenuItem
endIcon={<ChevronRight />}
onClick={() => setOpenMLSettings(true)}
label={pt("ML search")}
label={pt("Face and magic search")}
/>
</MenuItemGroup>
<MenuSectionTitle
title={pt(
"Face recognition, magic search and more",
)}
/>
</Box>
)}
</Stack>

View File

@@ -22,7 +22,7 @@ export const photosLogout = async () => {
// See: [Note: Caching IDB instances in separate execution contexts].
try {
terminateMLWorker();
await terminateMLWorker();
} catch (e) {
ignoreError("face", e);
}

View File

@@ -1,13 +1,12 @@
import { isDesktop } from "@/base/app";
import { ensureElectron } from "@/base/electron";
import log from "@/base/log";
import { FileType } from "@/media/file-type";
import {
clipMatches,
isMLEnabled,
isMLSupported,
mlStatusSnapshot,
} from "@/new/photos/services/ml";
import { clipMatches } from "@/new/photos/services/ml/clip";
import type { Person } from "@/new/photos/services/ml/people";
import { EnteFile } from "@/new/photos/types/file";
import * as chrono from "chrono-node";
@@ -374,7 +373,7 @@ const searchClip = async (
searchPhrase: string,
): Promise<ClipSearchScores | undefined> => {
if (!isMLEnabled()) return undefined;
const matches = await clipMatches(searchPhrase, ensureElectron());
const matches = await clipMatches(searchPhrase);
log.debug(() => ["clip/scores", matches]);
return matches;
};

View File

@@ -335,57 +335,21 @@ export interface Electron {
// - ML
/**
* Return a CLIP embedding of the given image.
* Create a new ML worker, terminating the older ones (if any).
*
* See: [Note: Natural language search using CLIP]
* This creates a new Node.js utility process, and sets things up so that we
* can communicate directly with that utility process using a
* {@link MessagePort} that gets posted using "createMLWorker/port".
*
* The input is a opaque float32 array representing the image. The layout
* and exact encoding of the input is specific to our implementation and the
* ML model (CLIP) we use.
* At the other end of that port will be an object that conforms to the
* {@link ElectronMLWorker} interface.
*
* @returns A CLIP embedding (an array of 512 floating point values).
* For more details about the IPC flow, see: [Note: ML IPC].
*
* Note: For simplicity of implementation, we assume that there is at most
* one outstanding call to {@link createMLWorker}.
*/
computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a CLIP embedding of the given image if we already have the model
* downloaded and prepped. If the model is not available return `undefined`.
*
* This differs from the other sibling ML functions in that it doesn't wait
* for the model download to finish. It does trigger a model download, but
* then immediately returns `undefined`. At some future point, when the
* model downloaded finishes, calls to this function will start returning
* the result we seek.
*
* The reason for doing it in this asymmetric way is because CLIP text
* embeddings are used as part of deducing user initiated search results,
* and we don't want to block that interaction on a large network request.
*
* See: [Note: Natural language search using CLIP]
*
* @param text The string whose embedding we want to compute.
*
* @returns A CLIP embedding.
*/
computeCLIPTextEmbeddingIfAvailable: (
text: string,
) => Promise<Float32Array | undefined>;
/**
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a MobileFaceNet embeddings for the given faces.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
createMLWorker: () => void;
// - Watch
@@ -574,6 +538,65 @@ export interface Electron {
clearPendingUploads: () => Promise<void>;
}
/**
* The shape of the object exposed by the Node.js ML worker process on the
* message port that the web layer obtains by doing {@link createMLWorker}.
*/
export interface ElectronMLWorker {
/**
* Return a CLIP embedding of the given image.
*
* See: [Note: Natural language search using CLIP]
*
* The input is an opaque float32 array representing the image. The layout
* and exact encoding of the input is specific to our implementation and the
* ML model (CLIP) we use.
*
* @returns A CLIP embedding (an array of 512 floating point values).
*/
computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a CLIP embedding of the given image if we already have the model
* downloaded and prepped. If the model is not available return `undefined`.
*
* This differs from the other sibling ML functions in that it doesn't wait
* for the model download to finish. It does trigger a model download, but
* then immediately returns `undefined`. At some future point, when the
* model download finishes, calls to this function will start returning
* the result we seek.
*
* The reason for doing it in this asymmetric way is because CLIP text
* embeddings are used as part of deducing user initiated search results,
* and we don't want to block that interaction on a large network request.
*
* See: [Note: Natural language search using CLIP]
*
* @param text The string whose embedding we want to compute.
*
* @returns A CLIP embedding.
*/
computeCLIPTextEmbeddingIfAvailable: (
text: string,
) => Promise<Float32Array | undefined>;
/**
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Return MobileFaceNet embeddings for the given faces.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
}
/**
* Errors that have special semantics on the web side.
*

View File

@@ -28,7 +28,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
/** The class (T) exposed by the web worker */
public remote: Promise<Remote<InstanceType<T>>>;
/** The web worker */
private worker: Worker;
public worker: Worker;
/** An arbitrary name associated with this ComlinkWorker for debugging. */
private name: string;

View File

@@ -124,7 +124,7 @@ export const MLSettings: React.FC<MLSettingsProps> = ({
<Stack spacing={"4px"} py={"12px"}>
<Titlebar
onClose={onClose}
title={pt("ML search")}
title={pt("Face and magic search")}
onRootClose={onRootClose}
/>
{component}
@@ -305,7 +305,7 @@ const ManageML: React.FC<ManageMLProps> = ({
let status: string;
switch (phase) {
case "indexing":
status = pt("Indexing");
status = pt("Running");
break;
case "scheduled":
status = pt("Scheduled");
@@ -319,9 +319,9 @@ const ManageML: React.FC<ManageMLProps> = ({
const confirmDisableML = () => {
setDialogBoxAttributesV2({
title: pt("Disable ML search"),
title: pt("Disable face and magic search"),
content: pt(
"Do you want to disable ML search on all your devices?",
"Do you want to disable face and magic search on all your devices?",
),
close: { text: t("cancel") },
proceed: {
@@ -356,7 +356,7 @@ const ManageML: React.FC<ManageMLProps> = ({
justifyContent={"space-between"}
>
<Typography color="text.faint">
{pt("Status")}
{pt("Indexing")}
</Typography>
<Typography>{status}</Typography>
</Stack>

View File

@@ -42,7 +42,7 @@ export const MLSettingsBeta: React.FC<MLSettingsBetaProps> = ({
<Stack spacing={"4px"} py={"12px"}>
<Titlebar
onClose={onClose}
title={pt("ML search")}
title={pt("Face and magic search")}
onRootClose={onRootClose}
/>

View File

@@ -1,4 +1,5 @@
import { basename } from "@/base/file";
import type { ElectronMLWorker } from "@/base/types/ipc";
import { FileType } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { ensure } from "@/utils/ensure";
@@ -7,7 +8,6 @@ import { renderableImageBlob } from "../../utils/file";
import { readStream } from "../../utils/native-stream";
import DownloadManager from "../download";
import type { UploadItem } from "../upload/types";
import type { MLWorkerElectron } from "./worker-types";
/**
* A pair of blobs - the original, and a possibly converted "renderable" one -
@@ -103,13 +103,14 @@ export const imageBitmapAndData = async (
* be set to the {@link UploadItem} that was uploaded. This way, we can directly
* use the on-disk file instead of needing to download the original from remote.
*
* @param electron The {@link MLWorkerElectron} instance that allows us to call
* our Node.js layer for various functionality.
* @param electron The {@link ElectronMLWorker} instance that stands as a
* witness that we're actually running in our desktop app (and thus can safely
* call our Node.js layer for various functionality).
*/
export const indexableBlobs = async (
enteFile: EnteFile,
uploadItem: UploadItem | undefined,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<IndexableBlobs> =>
uploadItem
? await indexableUploadItemBlobs(enteFile, uploadItem, electron)
@@ -118,7 +119,7 @@ export const indexableBlobs = async (
const indexableUploadItemBlobs = async (
enteFile: EnteFile,
uploadItem: UploadItem,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
) => {
const fileType = enteFile.metadata.fileType;
let originalImageBlob: Blob | undefined;
@@ -149,7 +150,7 @@ const indexableUploadItemBlobs = async (
*/
const readNonVideoUploadItem = async (
uploadItem: UploadItem,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<File> => {
if (typeof uploadItem == "string" || Array.isArray(uploadItem)) {
const { response, lastModifiedMs } = await readStream(

View File

@@ -1,9 +1,9 @@
import type { Electron } from "@/base/types/ipc";
import type { ElectronMLWorker } from "@/base/types/ipc";
import type { ImageBitmapAndData } from "./blob";
import { clipIndexes } from "./db";
import { pixelRGBBicubic } from "./image";
import { dotProduct, norm } from "./math";
import type { MLWorkerElectron } from "./worker-types";
import type { CLIPMatches } from "./worker-types";
/**
* The version of the CLIP indexing pipeline implemented by the current client.
@@ -98,19 +98,19 @@ export type LocalCLIPIndex = CLIPIndex & {
* be set to the {@link UploadItem} that was uploaded. This way, we can directly
* use the on-disk file instead of needing to download the original from remote.
*
* @param electron The {@link MLWorkerElectron} instance that allows us to call
* @param electron The {@link ElectronMLWorker} instance that allows us to call
* our Node.js layer to run the ONNX inference.
*/
export const indexCLIP = async (
image: ImageBitmapAndData,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<CLIPIndex> => ({
embedding: await computeEmbedding(image.data, electron),
});
const computeEmbedding = async (
imageData: ImageData,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<number[]> => {
const clipInput = convertToCLIPInput(imageData);
return normalized(await electron.computeCLIPImageEmbedding(clipInput));
@@ -167,26 +167,15 @@ const normalized = (embedding: Float32Array) => {
};
/**
* Use CLIP to perform a natural language search over image embeddings.
*
* @param searchPhrase The text entered by the user in the search box.
*
* @param electron The {@link Electron} instance to use to communicate with the
* native code running in our desktop app (the embedding happens in the native
* layer).
*
* It returns file (IDs) that should be shown in the search results. They're
* returned as a map from fileIDs to the scores they got (higher is better).
* This map will only contains entries whose score was above our minimum
* threshold.
* Find the files whose CLIP embedding "matches" the given {@link searchPhrase}.
*
* The result can also be `undefined`, which indicates that the download for the
* ML model is still in progress (trying again later should succeed).
*/
export const clipMatches = async (
searchPhrase: string,
electron: Electron,
): Promise<Map<number, number> | undefined> => {
electron: ElectronMLWorker,
): Promise<CLIPMatches | undefined> => {
const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase);
if (!t) return undefined;

View File

@@ -7,6 +7,7 @@
//
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import type { ElectronMLWorker } from "@/base/types/ipc";
import type { EnteFile } from "@/new/photos/types/file";
import { Matrix } from "ml-matrix";
import { getSimilarityTransformation } from "similarity-transformation";
@@ -24,7 +25,6 @@ import {
warpAffineFloat32List,
} from "./image";
import { clamp } from "./math";
import type { MLWorkerElectron } from "./worker-types";
/**
* The version of the face indexing pipeline implemented by the current client.
@@ -236,13 +236,13 @@ export interface Box {
*
* @param image The file's contents.
*
* @param electron The {@link MLWorkerElectron} instance that allows us to call
* @param electron The {@link ElectronMLWorker} instance that allows us to call
* our Node.js layer to run the ONNX inference.
*/
export const indexFaces = async (
enteFile: EnteFile,
{ data: imageData }: ImageBitmapAndData,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<FaceIndex> => ({
width: imageData.width,
height: imageData.height,
@@ -252,7 +252,7 @@ export const indexFaces = async (
const indexFaces_ = async (
fileID: number,
imageData: ImageData,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<Face[]> => {
const { width, height } = imageData;
const imageDimensions = { width, height };
@@ -316,7 +316,7 @@ const indexFaces_ = async (
*/
const detectFaces = async (
imageData: ImageData,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<YOLOFaceDetection[]> => {
const rect = ({ width, height }: Dimensions) => ({
x: 0,
@@ -878,7 +878,7 @@ const mobileFaceNetEmbeddingSize = 192;
*/
const computeEmbeddings = async (
faceData: Float32Array,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
): Promise<Float32Array[]> => {
const outputData = await electron.computeFaceEmbeddings(faceData);

View File

@@ -6,17 +6,20 @@ import { isDesktop } from "@/base/app";
import { blobCache } from "@/base/blob-cache";
import { ensureElectron } from "@/base/electron";
import log from "@/base/log";
import type { Electron } from "@/base/types/ipc";
import { ComlinkWorker } from "@/base/worker/comlink-worker";
import { FileType } from "@/media/file-type";
import type { EnteFile } from "@/new/photos/types/file";
import { ensure } from "@/utils/ensure";
import { throttled } from "@/utils/promise";
import { proxy } from "comlink";
import { proxy, transfer } from "comlink";
import { isInternalUser } from "../feature-flags";
import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
import type { UploadItem } from "../upload/types";
import { regenerateFaceCrops } from "./crop";
import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
import { MLWorker } from "./worker";
import type { CLIPMatches } from "./worker-types";
/**
* In-memory flag that tracks if ML is enabled.
@@ -33,7 +36,7 @@ import { MLWorker } from "./worker";
let _isMLEnabled = false;
/** Cached instance of the {@link ComlinkWorker} that wraps our web worker. */
let _comlinkWorker: ComlinkWorker<typeof MLWorker> | undefined;
let _comlinkWorker: Promise<ComlinkWorker<typeof MLWorker>> | undefined;
/**
* Subscriptions to {@link MLStatus}.
@@ -50,29 +53,28 @@ let _mlStatusListeners: (() => void)[] = [];
let _mlStatusSnapshot: MLStatus | undefined;
/** Lazily created, cached, instance of {@link MLWorker}. */
const worker = async () => {
if (!_comlinkWorker) _comlinkWorker = await createComlinkWorker();
return _comlinkWorker.remote;
};
const worker = () =>
(_comlinkWorker ??= createComlinkWorker()).then((cw) => cw.remote);
const createComlinkWorker = async () => {
const electron = ensureElectron();
const mlWorkerElectron = {
detectFaces: electron.detectFaces,
computeFaceEmbeddings: electron.computeFaceEmbeddings,
computeCLIPImageEmbedding: electron.computeCLIPImageEmbedding,
};
const delegate = {
workerDidProcessFile,
};
// Obtain a message port from the Electron layer.
const messagePort = await createMLWorker(electron);
const cw = new ComlinkWorker<typeof MLWorker>(
"ML",
new Worker(new URL("worker.ts", import.meta.url)),
);
await cw.remote.then((w) =>
w.init(proxy(mlWorkerElectron), proxy(delegate)),
// Forward the port to the web worker.
w.init(transfer(messagePort, [messagePort]), proxy(delegate)),
);
return cw;
};
@@ -85,13 +87,40 @@ const createComlinkWorker = async () => {
*
* It is also called when the user pauses or disables ML.
*/
export const terminateMLWorker = () => {
export const terminateMLWorker = async () => {
if (_comlinkWorker) {
_comlinkWorker.terminate();
await _comlinkWorker.then((cw) => cw.terminate());
_comlinkWorker = undefined;
}
};
/**
* Obtain a port from the Node.js layer that can be used to communicate with the
* ML worker process.
*/
const createMLWorker = (electron: Electron): Promise<MessagePort> => {
// The main process will do its thing, and send back the port it created to
// us by sending a message on the "createMLWorker/port" channel via the
// postMessage API. This roundabout way is needed because MessagePorts
// cannot be transferred via the usual send/invoke pattern.
const port = new Promise<MessagePort>((resolve) => {
const l = ({ source, data, ports }: MessageEvent) => {
// The source check verifies that the message is coming from our own
// preload script. The data is the message that was posted.
if (source == window && data == "createMLWorker/port") {
window.removeEventListener("message", l);
resolve(ensure(ports[0]));
}
};
window.addEventListener("message", l);
});
electron.createMLWorker();
return port;
};
/**
* Return true if the current client supports ML.
*
@@ -163,7 +192,7 @@ export const disableML = async () => {
await updateIsMLEnabledRemote(false);
setIsMLEnabledLocal(false);
_isMLEnabled = false;
terminateMLWorker();
await terminateMLWorker();
triggerStatusUpdate();
};
@@ -369,6 +398,22 @@ const setInterimScheduledStatus = () => {
const workerDidProcessFile = throttled(updateMLStatusSnapshot, 2000);
/**
* Use CLIP to perform a natural language search over image embeddings.
*
* @param searchPhrase The text entered by the user in the search box.
*
* It returns the IDs of the files that should be shown in the search results,
* along with their scores.
*
* The result can also be `undefined`, which indicates that the download for the
* ML model is still in progress (trying again later should succeed).
*/
export const clipMatches = (
searchPhrase: string,
): Promise<CLIPMatches | undefined> =>
worker().then((w) => w.clipMatches(searchPhrase));
/**
* Return the IDs of all the faces in the given {@link enteFile} that are not
* associated with a person cluster.

View File

@@ -1,22 +1,7 @@
/**
* @file Type for the objects shared (as a Comlink proxy) by the main thread and
* the ML worker.
* @file Types for the objects shared between the main thread and the ML worker.
*/
/**
* A subset of {@link Electron} provided to the {@link MLWorker}.
*
* `globalThis.electron` does not exist in the execution context of web workers.
* So instead, we manually provide a proxy object of type
* {@link MLWorkerElectron} that exposes a subset of the functions from
* {@link Electron} that are needed by the code running in the ML web worker.
*/
export interface MLWorkerElectron {
detectFaces: (input: Float32Array) => Promise<Float32Array>;
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
}
/**
* Callbacks invoked by the worker at various points in the indexing pipeline to
* notify the main thread of events it might be interested in.
@@ -25,7 +10,18 @@ export interface MLWorkerDelegate {
/**
* Called whenever a file is processed during indexing.
*
* It is called both when the indexing was successful or failed.
* It is called both when the indexing succeeds and when it fails.
*/
workerDidProcessFile: () => void;
}
/**
* The file IDs that should be considered as matches for a particular search
* phrase, each with their associated score.
*
* This is a map from the matching file IDs to the scores they got (higher is
* better). The map only contains entries whose score was above our minimum
* threshold.
*/
export type CLIPMatches = Map<number, number>;

View File

@@ -3,14 +3,15 @@ import { isHTTP4xxError } from "@/base/http";
import { getKVN } from "@/base/kv";
import { ensureAuthToken } from "@/base/local-user";
import log from "@/base/log";
import type { ElectronMLWorker } from "@/base/types/ipc";
import type { EnteFile } from "@/new/photos/types/file";
import { fileLogID } from "@/new/photos/utils/file";
import { ensure } from "@/utils/ensure";
import { wait } from "@/utils/promise";
import { DOMParser } from "@xmldom/xmldom";
import { expose } from "comlink";
import { expose, wrap } from "comlink";
import downloadManager from "../download";
import { cmpNewLib2, extractRawExif } from "../exif";
import { cmpNewLib2, extractRawExif, type RawExifTags } from "../exif";
import { getAllLocalFiles, getLocalTrashedFiles } from "../files";
import type { UploadItem } from "../upload/types";
import {
@@ -18,7 +19,12 @@ import {
indexableBlobs,
type ImageBitmapAndData,
} from "./blob";
import { clipIndexingVersion, indexCLIP, type CLIPIndex } from "./clip";
import {
clipIndexingVersion,
clipMatches,
indexCLIP,
type CLIPIndex,
} from "./clip";
import { saveFaceCrops } from "./crop";
import {
indexableFileIDs,
@@ -29,10 +35,11 @@ import {
import {
fetchDerivedData,
putDerivedData,
type RawRemoteDerivedData,
type RemoteDerivedData,
} from "./embedding";
import { faceIndexingVersion, indexFaces, type FaceIndex } from "./face";
import type { MLWorkerDelegate, MLWorkerElectron } from "./worker-types";
import type { CLIPMatches, MLWorkerDelegate } from "./worker-types";
const idleDurationStart = 5; /* 5 seconds */
const idleDurationMax = 16 * 60; /* 16 minutes */
@@ -67,9 +74,12 @@ interface IndexableItem {
* - "backfillq": fetching remote embeddings of unindexed items, and then
* indexing them if needed,
* - "idle": in between state transitions.
*
* In addition, MLWorker can also be invoked for interactive tasks: in
* particular, for finding the closest CLIP match when the user does a search.
*/
export class MLWorker {
private electron: MLWorkerElectron | undefined;
private electron: ElectronMLWorker | undefined;
private delegate: MLWorkerDelegate | undefined;
private state: "idle" | "indexing" = "idle";
private liveQ: IndexableItem[] = [];
@@ -82,15 +92,16 @@ export class MLWorker {
* This is conceptually the constructor; however, it is easier to have this
* as a separate function to avoid complicating the comlink types further.
*
* @param electron The {@link MLWorkerElectron} that allows the worker to
* use the functionality provided by our Node.js layer when running in the
* context of our desktop app.
* @param port A {@link MessagePort} that allows us to communicate with an
* Electron utility process running in the Node.js layer of our desktop app,
* exposing an object that conforms to the {@link ElectronMLWorker}
* interface.
*
* @param delegate The {@link MLWorkerDelegate} the worker can use to inform
* the main thread of interesting events.
*/
async init(electron: MLWorkerElectron, delegate?: MLWorkerDelegate) {
this.electron = electron;
async init(port: MessagePort, delegate: MLWorkerDelegate) {
this.electron = wrap<ElectronMLWorker>(port);
this.delegate = delegate;
// Initialize the downloadManager running in the web worker with the
// user's token. It'll be used to download files to index if needed.
@@ -176,6 +187,13 @@ export class MLWorker {
return this.state == "indexing";
}
/**
* Find {@link CLIPMatches} for a given {@link searchPhrase}.
*/
async clipMatches(searchPhrase: string): Promise<CLIPMatches | undefined> {
return clipMatches(searchPhrase, ensure(this.electron));
}
private async tick() {
log.debug(() => [
"ml/tick",
@@ -224,7 +242,7 @@ export class MLWorker {
}
/** Return the next batch of items to backfill (if any). */
async backfillQ() {
private async backfillQ() {
const userID = ensure(await getKVN("userID"));
// Find files that our local DB thinks need syncing.
const filesByID = await syncWithLocalFilesAndGetFilesToIndex(
@@ -256,7 +274,7 @@ expose(MLWorker);
*/
const indexNextBatch = async (
items: IndexableItem[],
electron: MLWorkerElectron,
electron: ElectronMLWorker,
delegate: MLWorkerDelegate | undefined,
) => {
// Don't try to index if we wouldn't be able to upload them anyway. The
@@ -270,18 +288,37 @@ const indexNextBatch = async (
// Nothing to do.
if (items.length == 0) return false;
// Index, keeping track if any of the items failed.
// Keep track if any of the items failed.
let allSuccess = true;
for (const item of items) {
try {
await index(item, electron);
delegate?.workerDidProcessFile();
// Possibly unnecessary, but let us drain the microtask queue.
await wait(0);
} catch (e) {
log.warn(`Skipping unindexable file ${item.enteFile.id}`, e);
allSuccess = false;
// Index up to 4 items simultaneously.
const tasks = new Array<Promise<void> | undefined>(4).fill(undefined);
let i = 0;
while (i < items.length) {
for (let j = 0; j < tasks.length; j++) {
if (i < items.length && !tasks[j]) {
tasks[j] = index(ensure(items[i++]), electron)
.then(() => {
tasks[j] = undefined;
})
.catch(() => {
allSuccess = false;
tasks[j] = undefined;
});
}
}
// Wait for at least one to complete (the other runners continue running
// even if one promise reaches the finish line).
await Promise.race(tasks);
// Let the main thread know we're doing something.
delegate?.workerDidProcessFile();
// Let us drain the microtask queue. This also gives a chance for other
// interactive tasks like `clipMatches` to run.
await wait(0);
}
// Return true if nothing failed.
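// [Editor's sketch, not part of this commit] The same "at most N tasks in
// flight" idea as the loop above, written as a small generic helper instead of
// a fixed-size slot array. All names here are illustrative.
const runWithConcurrency = async <T>(
    items: T[],
    limit: number,
    task: (item: T) => Promise<void>,
): Promise<boolean> => {
    let allSuccess = true;
    const inFlight = new Set<Promise<void>>();
    for (const item of items) {
        const p: Promise<void> = task(item)
            .catch(() => {
                allSuccess = false;
            })
            .finally(() => inFlight.delete(p));
        inFlight.add(p);
        // Once the limit is reached, wait for any one task to settle before
        // starting the next one.
        if (inFlight.size >= limit) await Promise.race(inFlight);
    }
    // Wait for the stragglers, then report whether everything succeeded.
    await Promise.all(inFlight);
    return allSuccess;
};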
@@ -374,7 +411,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
*/
const index = async (
{ enteFile, uploadItem, remoteDerivedData }: IndexableItem,
electron: MLWorkerElectron,
electron: ElectronMLWorker,
) => {
const f = fileLogID(enteFile);
const fileID = enteFile.id;
@@ -463,10 +500,7 @@ const index = async (
[faceIndex, clipIndex, exif] = await Promise.all([
existingFaceIndex ?? indexFaces(enteFile, image, electron),
existingCLIPIndex ?? indexCLIP(image, electron),
existingExif ??
(originalImageBlob
? extractRawExif(originalImageBlob)
: undefined),
existingExif ?? tryExtractExif(originalImageBlob, f),
]);
} catch (e) {
// See: [Note: Transient and permanent indexing failures]
@@ -475,8 +509,12 @@ const index = async (
throw e;
}
if (originalImageBlob && exif)
await cmpNewLib2(enteFile, originalImageBlob, exif);
try {
if (originalImageBlob && exif)
await cmpNewLib2(enteFile, originalImageBlob, exif);
} catch (e) {
log.warn(`Skipping exif cmp for ${f}`, e);
}
log.debug(() => {
const ms = Date.now() - startTime;
@@ -504,22 +542,27 @@ const index = async (
// parts. See: [Note: Preserve unknown derived data fields].
const existingRawDerivedData = remoteDerivedData?.raw ?? {};
const rawDerivedData = {
const rawDerivedData: RawRemoteDerivedData = {
...existingRawDerivedData,
face: remoteFaceIndex,
clip: remoteCLIPIndex,
exif,
...(exif ? { exif } : {}),
};
log.debug(() => ["Uploading derived data", rawDerivedData]);
if (existingFaceIndex && existingCLIPIndex && !exif) {
// If we were indexing just for the Exif, but the Exif extraction didn't
// produce anything, there is no need to upload.
} else {
log.debug(() => ["Uploading derived data", rawDerivedData]);
try {
await putDerivedData(enteFile, rawDerivedData);
} catch (e) {
// See: [Note: Transient and permanent indexing failures]
log.error(`Failed to put derived data for ${f}`, e);
if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
throw e;
try {
await putDerivedData(enteFile, rawDerivedData);
} catch (e) {
// See: [Note: Transient and permanent indexing failures]
log.error(`Failed to put derived data for ${f}`, e);
if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
throw e;
}
}
try {
@@ -550,3 +593,34 @@ const index = async (
image.bitmap.close();
}
};
/**
* A helper function that tries to extract the raw Exif, but returns `undefined`
* if something goes wrong (or it isn't possible) instead of throwing.
*
* Exif extraction is not a critical step; we don't want the actual indexing to
* fail just because we were unable to extract the Exif. This is not rare: one
* scenario is indexing a file in an exotic format. The ML indexing will
* succeed (because we convert the file to a renderable blob), but the Exif
* extraction will fail (since it needs the original blob, which can be in an
* arbitrary format).
*
* @param originalImageBlob A {@link Blob} containing the original data for the
* image (or the image component of a live photo) whose Exif we're trying to
* extract. If this is not available, we skip the extraction and return
* `undefined`.
*
* @param f The {@link fileLogID} for the file this blob corresponds to.
*/
export const tryExtractExif = async (
originalImageBlob: Blob | undefined,
f: string,
): Promise<RawExifTags | undefined> => {
if (!originalImageBlob) return undefined;
try {
return await extractRawExif(originalImageBlob);
} catch (e) {
log.warn(`Ignoring error during Exif extraction for ${f}`, e);
return undefined;
}
};

View File

@@ -6,8 +6,7 @@
* See: [Note: IPC streams].
*/
import type { Electron, ZipItem } from "@/base/types/ipc";
import type { MLWorkerElectron } from "../services/ml/worker-types";
import type { Electron, ElectronMLWorker, ZipItem } from "@/base/types/ipc";
/**
* Stream the given file or zip entry from the user's local file system.
@@ -18,7 +17,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
*
* To avoid accidentally invoking it in a non-desktop app context, it requires
* the {@link Electron} (or a functionally similar) object as a parameter (even
* though it doesn't use it).
* though it doesn't need or use it).
*
* @param pathOrZipItem Either the path on the file on the user's local file
* system whose contents we want to stream. Or a tuple containing the path to a
@@ -36,7 +35,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
* reading, expressed as epoch milliseconds.
*/
export const readStream = async (
_: Electron | MLWorkerElectron,
_: Electron | ElectronMLWorker,
pathOrZipItem: string | ZipItem,
): Promise<{ response: Response; size: number; lastModifiedMs: number }> => {
let url: URL;