[desktop] Use a utility process for ML indexing (#2583)
* This ensures that the UI remains responsive when indexing is running.
* This allows us to parallelize the in-flight networking etc. instead of waterfalling everything.
* The actual CPU-intensive indexing is still serial.
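In condensed form, the port plumbing this adds looks roughly like the following sketch (distilled from `createMLWorker` in the diff below; illustrative, not a drop-in):

```ts
// Main process — condensed sketch of the flow this commit implements.
import { MessageChannelMain, utilityProcess, app } from "electron/main";
import type { BrowserWindow } from "electron";
import path from "node:path";

const createMLWorkerSketch = (window: BrowserWindow) => {
    // One end of the channel goes to the utility process, the other to the
    // renderer; thereafter the two talk directly, skipping the main process.
    const { port1, port2 } = new MessageChannelMain();
    const child = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
    child.postMessage({ userDataPath: app.getPath("userData") }, [port1]);
    window.webContents.postMessage("createMLWorker/port", undefined, [port2]);
    return child;
};
```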
@@ -28,6 +28,7 @@
        "auto-launch": "^5.0",
        "chokidar": "^3.6",
        "clip-bpe-js": "^0.0.6",
        "comlink": "^4.4.1",
        "compare-versions": "^6.1",
        "electron-log": "^5.1",
        "electron-store": "^8.2",

@@ -21,6 +21,7 @@ import {
    attachFSWatchIPCHandlers,
    attachIPCHandlers,
    attachLogoutIPCHandler,
    attachMainWindowIPCHandlers,
} from "./main/ipc";
import log, { initLogging } from "./main/log";
import { createApplicationMenu, createTrayContextMenu } from "./main/menu";
@@ -121,6 +122,7 @@ const main = () => {
    // Setup IPC and streams.
    const watcher = createWatcher(mainWindow);
    attachIPCHandlers();
    attachMainWindowIPCHandlers(mainWindow);
    attachFSWatchIPCHandlers(watcher);
    attachLogoutIPCHandler(watcher);
    registerStreamProtocol();

@@ -9,6 +9,7 @@
 */

import type { FSWatcher } from "chokidar";
import type { BrowserWindow } from "electron";
import { ipcMain } from "electron/main";
import type {
    CollectionMapping,
@@ -42,11 +43,7 @@ import {
} from "./services/fs";
import { convertToJPEG, generateImageThumbnail } from "./services/image";
import { logout } from "./services/logout";
import {
    computeCLIPImageEmbedding,
    computeCLIPTextEmbeddingIfAvailable,
} from "./services/ml-clip";
import { computeFaceEmbeddings, detectFaces } from "./services/ml-face";
import { createMLWorker } from "./services/ml";
import {
    encryptionKey,
    lastShownChangelogVersion,
@@ -184,24 +181,6 @@ export const attachIPCHandlers = () => {
        ) => ffmpegExec(command, dataOrPathOrZipItem, outputFileExtension),
    );

    // - ML

    ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) =>
        computeCLIPImageEmbedding(input),
    );

    ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>
        computeCLIPTextEmbeddingIfAvailable(text),
    );

    ipcMain.handle("detectFaces", (_, input: Float32Array) =>
        detectFaces(input),
    );

    ipcMain.handle("computeFaceEmbeddings", (_, input: Float32Array) =>
        computeFaceEmbeddings(input),
    );

    // - Upload

    ipcMain.handle("listZipItems", (_, zipPath: string) =>
@@ -231,6 +210,16 @@ export const attachIPCHandlers = () => {
    ipcMain.handle("clearPendingUploads", () => clearPendingUploads());
};

/**
 * A subset of {@link attachIPCHandlers} for functions that need a reference to
 * the main window to do their thing.
 */
export const attachMainWindowIPCHandlers = (mainWindow: BrowserWindow) => {
    // - ML

    ipcMain.on("createMLWorker", () => createMLWorker(mainWindow));
};

/**
 * Sibling of {@link attachIPCHandlers} that attaches handlers specific to the
 * watch folder functionality.

@@ -1,68 +0,0 @@
/**
 * @file Compute CLIP embeddings for images and text.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */

import Tokenizer from "clip-bpe-js";
import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure, wait } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";

const cachedCLIPImageSession = makeCachedInferenceSession(
    "clip-image-vit-32-float32.onnx",
    351468764 /* 335.2 MB */,
);

export const computeCLIPImageEmbedding = async (input: Float32Array) => {
    const session = await cachedCLIPImageSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to type the result */
    return ensure(results.output).data as Float32Array;
};

const cachedCLIPTextSession = makeCachedInferenceSession(
    "clip-text-vit-32-uint8.onnx",
    64173509 /* 61.2 MB */,
);

let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
    if (!_tokenizer) _tokenizer = new Tokenizer();
    return _tokenizer;
};

export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
    const sessionOrSkip = await Promise.race([
        cachedCLIPTextSession(),
        // Wait for a tick to let the session promise resolve the first time
        // this code runs on each app start (and the model has been
        // downloaded).
        wait(0).then(() => 1),
    ]);

    // Don't wait for the download to complete.
    if (typeof sessionOrSkip == "number") {
        log.info(
            "Ignoring CLIP text embedding request because model download is pending",
        );
        return undefined;
    }

    const session = sessionOrSkip;
    const t = Date.now();
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };

    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
    return ensure(results.output).data as Float32Array;
};

@@ -1,53 +0,0 @@
/**
 * @file Various face recognition related tasks.
 *
 * - Face detection with the YOLO model.
 * - Face embedding with the MobileFaceNet model.
 *
 * The runtime used is ONNX.
 */

import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";

const cachedFaceDetectionSession = makeCachedInferenceSession(
    "yolov5s_face_640_640_dynamic.onnx",
    30762872 /* 29.3 MB */,
);

export const detectFaces = async (input: Float32Array) => {
    const session = await cachedFaceDetectionSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
    return ensure(results.output).data;
};

const cachedFaceEmbeddingSession = makeCachedInferenceSession(
    "mobilefacenet_opset15.onnx",
    5286998 /* 5 MB */,
);

export const computeFaceEmbeddings = async (input: Float32Array) => {
    // Dimension of each face (alias)
    const mobileFaceNetFaceSize = 112;
    // Smaller alias
    const z = mobileFaceNetFaceSize;
    // Number of faces in the batch (each face is z * z * 3 floats)
    const n = Math.round(input.length / (z * z * 3));
    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

    const session = await cachedFaceEmbeddingSession();
    const t = Date.now();
    const feeds = { img_inputs: inputTensor };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to extract and type the result */
    return (results.embeddings as unknown as Record<string, unknown>)
        .cpuData as Float32Array;
};

desktop/src/main/services/ml-worker.ts (new file, 315 lines)
@@ -0,0 +1,315 @@
/**
 * @file ML related tasks. This code runs in a utility process.
 *
 * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
 * for various tasks are not shipped with the app but are downloaded on demand.
 */

// See [Note: Using Electron APIs in UtilityProcess] about what we can and
// cannot import.

import Tokenizer from "clip-bpe-js";
import { expose } from "comlink";
import { net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import { messagePortMainEndpoint } from "../utils/comlink";
import { ensure, wait } from "../utils/common";
import { writeStream } from "../utils/stream";

/**
 * We cannot do
 *
 *     import log from "../log";
 *
 * because that requires the Electron APIs that are not available to a utility
 * process (See: [Note: Using Electron APIs in UtilityProcess]). But even if
 * that were to work, logging will still be problematic since we'd try opening
 * the log file from two different Node.js processes (this one, and the main
 * one), and I didn't find any indication in the electron-log repository that
 * the log file's integrity would be maintained in such cases.
 *
 * So instead we create this proxy log object that uses `process.parentPort` to
 * transport the logs over to the main process.
 */
const log = {
    /**
     * Unlike the real {@link log.error}, this accepts only the first string
     * argument, not the second optional error one.
     */
    errorString: (s: string) => mainProcess("log.errorString", s),
    info: (...ms: unknown[]) => mainProcess("log.info", ms),
    /**
     * Unlike the real {@link log.debug}, this is (a) eagerly evaluated, and (b)
     * accepts only strings.
     */
    debugString: (s: string) => mainProcess("log.debugString", s),
};

/**
 * Send a message to the main process using a barebones RPC protocol.
 */
const mainProcess = (method: string, param: unknown) =>
    process.parentPort.postMessage({ method, p: param });

log.debugString(`Started ML worker process`);

process.parentPort.once("message", (e) => {
    // Initialize ourselves with the data we got from our parent.
    parseInitData(e.data);
    // Expose an instance of `ElectronMLWorker` on the port we got from our
    // parent.
    expose(
        {
            computeCLIPImageEmbedding,
            computeCLIPTextEmbeddingIfAvailable,
            detectFaces,
            computeFaceEmbeddings,
        },
        messagePortMainEndpoint(ensure(e.ports[0])),
    );
});

/**
 * We cannot access Electron's {@link app} object within a utility process, so
 * we pass the value of `app.getPath("userData")` during initialization, and it
 * can be subsequently retrieved from here.
 */
let _userDataPath: string | undefined;

/** Equivalent to app.getPath("userData") */
const userDataPath = () => ensure(_userDataPath);

const parseInitData = (data: unknown) => {
    if (
        data &&
        typeof data == "object" &&
        "userDataPath" in data &&
        typeof data.userDataPath == "string"
    ) {
        _userDataPath = data.userDataPath;
    } else {
        log.errorString("Unparseable initialization data");
    }
};

/**
 * Return a function that can be used to trigger a download of the specified
 * model, and the creation of an ONNX inference session initialized using it.
 *
 * Multiple parallel calls to the returned function are fine; it ensures that
 * the model will be downloaded and the session created using it only once.
 * All pending calls to it meanwhile will just await on the same promise.
 *
 * And once the promise is resolved, the created ONNX inference session will be
 * cached, so subsequent calls to the returned function will just reuse the same
 * session.
 *
 * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
 * actively trigger a download until the returned function is called.
 *
 * @param modelName The name of the model to download.
 *
 * @param modelByteSize The size in bytes that we expect the model to have. If
 * the size of the downloaded model does not match the expected size, then we
 * will redownload it.
 *
 * @returns A function. Calling that function returns a promise to an ONNX
 * session.
 */
const makeCachedInferenceSession = (
    modelName: string,
    modelByteSize: number,
) => {
    let session: Promise<ort.InferenceSession> | undefined;

    const download = () =>
        modelPathDownloadingIfNeeded(modelName, modelByteSize);

    const createSession = (modelPath: string) =>
        createInferenceSession(modelPath);

    const cachedInferenceSession = () => {
        if (!session) session = download().then(createSession);
        return session;
    };

    return cachedInferenceSession;
};

/**
 * Download the model named {@link modelName} if we don't already have it.
 *
 * Also verify that the size of the model we get matches
 * {@link expectedByteSize} (if not, redownload it).
 *
 * @returns the path to the model on the local machine.
 */
const modelPathDownloadingIfNeeded = async (
    modelName: string,
    expectedByteSize: number,
) => {
    const modelPath = modelSavePath(modelName);

    if (!existsSync(modelPath)) {
        log.info("CLIP image model not found, downloading");
        await downloadModel(modelPath, modelName);
    } else {
        const size = (await fs.stat(modelPath)).size;
        if (size !== expectedByteSize) {
            log.errorString(
                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
            );
            await downloadModel(modelPath, modelName);
        }
    }

    return modelPath;
};

/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
    path.join(userDataPath(), "models", modelName);

const downloadModel = async (saveLocation: string, name: string) => {
    // `mkdir -p` the directory where we want to save the model.
    const saveDir = path.dirname(saveLocation);
    await fs.mkdir(saveDir, { recursive: true });
    // Download.
    log.info(`Downloading ML model from ${name}`);
    const url = `https://models.ente.io/${name}`;
    const res = await net.fetch(url);
    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
    const body = res.body;
    if (!body) throw new Error(`Received a null response for ${url}`);
    // Save.
    await writeStream(saveLocation, body);
    log.info(`Downloaded CLIP model ${name}`);
};

/**
 * Create an ONNX {@link InferenceSession} with some defaults.
 */
const createInferenceSession = async (modelPath: string) => {
    return await ort.InferenceSession.create(modelPath, {
        // Restrict the number of threads to 1.
        intraOpNumThreads: 1,
        // Be more conservative with RAM usage.
        enableCpuMemArena: false,
    });
};

const cachedCLIPImageSession = makeCachedInferenceSession(
    "clip-image-vit-32-float32.onnx",
    351468764 /* 335.2 MB */,
);

/**
 * Compute CLIP embeddings for an image.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */
export const computeCLIPImageEmbedding = async (input: Float32Array) => {
    const session = await cachedCLIPImageSession();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
    };
    const t = Date.now();
    const results = await session.run(feeds);
    log.debugString(`ONNX/CLIP image embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to type the result */
    return ensure(results.output).data as Float32Array;
};

const cachedCLIPTextSession = makeCachedInferenceSession(
    "clip-text-vit-32-uint8.onnx",
    64173509 /* 61.2 MB */,
);

let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
    if (!_tokenizer) _tokenizer = new Tokenizer();
    return _tokenizer;
};

/**
 * Compute CLIP embeddings for a text snippet.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
    const sessionOrSkip = await Promise.race([
        cachedCLIPTextSession(),
        // Wait for a tick to let the session promise resolve the first time
        // this code runs on each app start (and the model has been
        // downloaded).
        wait(0).then(() => 1),
    ]);

    // Don't wait for the download to complete.
    if (typeof sessionOrSkip == "number") {
        log.info(
            "Ignoring CLIP text embedding request because model download is pending",
        );
        return undefined;
    }

    const session = sessionOrSkip;
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };

    const t = Date.now();
    const results = await session.run(feeds);
    log.debugString(`ONNX/CLIP text embedding took ${Date.now() - t} ms`);
    return ensure(results.output).data as Float32Array;
};

const cachedFaceDetectionSession = makeCachedInferenceSession(
    "yolov5s_face_640_640_dynamic.onnx",
    30762872 /* 29.3 MB */,
);

/**
 * Face detection with the YOLO model and ONNX runtime.
 */
export const detectFaces = async (input: Float32Array) => {
    const session = await cachedFaceDetectionSession();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
    };
    const t = Date.now();
    const results = await session.run(feeds);
    log.debugString(`ONNX/YOLO face detection took ${Date.now() - t} ms`);
    return ensure(results.output).data;
};

const cachedFaceEmbeddingSession = makeCachedInferenceSession(
    "mobilefacenet_opset15.onnx",
    5286998 /* 5 MB */,
);

/**
 * Face embedding with the MobileFaceNet model and ONNX runtime.
 */
export const computeFaceEmbeddings = async (input: Float32Array) => {
    // Dimension of each face (alias)
    const mobileFaceNetFaceSize = 112;
    // Smaller alias
    const z = mobileFaceNetFaceSize;
    // Number of faces in the batch (each face is z * z * 3 floats)
    const n = Math.round(input.length / (z * z * 3));
    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

    const session = await cachedFaceEmbeddingSession();
    const feeds = { img_inputs: inputTensor };
    const t = Date.now();
    const results = await session.run(feeds);
    log.debugString(`ONNX/MFNT face embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to extract and type the result */
    return (results.embeddings as unknown as Record<string, unknown>)
        .cpuData as Float32Array;
};

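At the other end of the port, the web layer wraps the exposed object with comlink. The web-worker code itself is not part of this diff, so the following is only a hypothetical sketch of the receiving side (the import path and event shape are assumptions):

```ts
// Web worker (renderer side) — hypothetical sketch; not part of this diff.
import { wrap, type Remote } from "comlink";
import type { ElectronMLWorker } from "@/base/types/ipc";

let electron: Remote<ElectronMLWorker> | undefined;

// The renderer main thread relays the MessagePort it received from the
// preload script; comlink then proxies method calls over that port.
self.addEventListener("message", (event: MessageEvent) => {
    const port = event.ports[0];
    if (port) electron = wrap<ElectronMLWorker>(port);
});

// Thereafter, inference runs in the utility process, e.g.:
//   const embedding = await electron.computeCLIPImageEmbedding(input);
```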
@@ -1,126 +1,147 @@
/**
 * @file ML related functionality, generic layer.
 *
 * @see also `ml-clip.ts`, `ml-face.ts`.
 *
 * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
 * for various tasks are not shipped with the app but are downloaded on demand.
 *
 * The primary reason for doing these tasks in the Node.js layer is so that we
 * can use the binary ONNX runtime which is 10-20x faster than the WASM based
 * web one.
 * @file ML related functionality. This code runs in the main process.
 */

import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import {
    MessageChannelMain,
    type BrowserWindow,
    type UtilityProcess,
} from "electron";
import { app, utilityProcess } from "electron/main";
import path from "node:path";
import * as ort from "onnxruntime-node";
import log from "../log";
import { writeStream } from "../stream";

/** The active ML worker (utility) process, if any. */
let _child: UtilityProcess | undefined;

/**
 * Return a function that can be used to trigger a download of the specified
 * model, and the creation of an ONNX inference session initialized using it.
 * Create a new ML worker process, terminating the older ones (if any).
 *
 * Multiple parallel calls to the returned function are fine; it ensures that
 * the model will be downloaded and the session created using it only once.
 * All pending calls to it meanwhile will just await on the same promise.
 * [Note: ML IPC]
 *
 * And once the promise is resolved, the created ONNX inference session will be
 * cached, so subsequent calls to the returned function will just reuse the same
 * session.
 * The primary reason for doing ML tasks in the Node.js layer is so that we can
 * use the binary ONNX runtime, which is 10-20x faster than the WASM one that
 * can be used directly on the web layer.
 *
 * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
 * actively trigger a download until the returned function is called.
 * For this to work, the main and renderer process need to communicate with each
 * other. Further, in the web layer the ML indexing runs in a web worker (so as
 * to not get in the way of the main thread). So the communication has 2 hops:
 *
 * @param modelName The name of the model to download.
 *     Node.js main <-> Renderer main <-> Renderer web worker
 *
 * @param modelByteSize The size in bytes that we expect the model to have. If
 * the size of the downloaded model does not match the expected size, then we
 * will redownload it.
 * This naive way works, but has a problem. The Node.js main process is in the
 * code path for delivering user events to the renderer process. The ML tasks we
 * do take in the order of 100-300 ms (possibly more) for each individual
 * inference. Thus, the Node.js main process is busy for those 100-300 ms, and
 * does not forward events to the renderer, causing the UI to jitter.
 *
 * @returns A function. Calling that function returns a promise to an ONNX
 * session.
 * The solution for this is to spawn an Electron UtilityProcess, which we can
 * think of as a regular Node.js child process. This frees up the Node.js main
 * process, and would remove the jitter.
 * https://www.electronjs.org/docs/latest/tutorial/process-model
 *
 * It would seem that this introduces another hop in our IPC
 *
 *     Node.js utility process <-> Node.js main <-> ...
 *
 * but here we can use the special bit about Electron utility processes that
 * separates them from regular Node.js child processes: their support for
 * message ports. https://www.electronjs.org/docs/latest/tutorial/message-ports
 *
 * As a brief summary, a MessagePort is a web feature that allows two contexts
 * to communicate. A pair of message ports is called a message channel. The cool
 * thing about these is that we can pass these ports themselves over IPC.
 *
 * > One caveat here is that the message ports can only be passed using the
 * > `postMessage` APIs, not the usual send/invoke APIs.
 *
 * So we
 *
 * 1. In the main process, create a message channel.
 * 2. Spawn a utility process, and send one port of the pair to it.
 * 3. Send the other port of the pair to the renderer.
 *
 * The renderer will forward that port to the web worker that is coordinating
 * the ML indexing on the web layer. Thereafter, the utility process and web
 * worker can directly talk to each other!
 *
 *     Node.js utility process <-> Renderer web worker
 *
 * The RPC protocol is handled using comlink on both ends. The port itself needs
 * to be relayed using `postMessage`.
 */
export const makeCachedInferenceSession = (
    modelName: string,
    modelByteSize: number,
) => {
    let session: Promise<ort.InferenceSession> | undefined;

    const download = () =>
        modelPathDownloadingIfNeeded(modelName, modelByteSize);

    const createSession = (modelPath: string) =>
        createInferenceSession(modelPath);

    const cachedInferenceSession = () => {
        if (!session) session = download().then(createSession);
        return session;
    };

    return cachedInferenceSession;
};

/**
 * Download the model named {@link modelName} if we don't already have it.
 *
 * Also verify that the size of the model we get matches
 * {@link expectedByteSize} (if not, redownload it).
 *
 * @returns the path to the model on the local machine.
 */
const modelPathDownloadingIfNeeded = async (
    modelName: string,
    expectedByteSize: number,
) => {
    const modelPath = modelSavePath(modelName);

    if (!existsSync(modelPath)) {
        log.info("CLIP image model not found, downloading");
        await downloadModel(modelPath, modelName);
    } else {
        const size = (await fs.stat(modelPath)).size;
        if (size !== expectedByteSize) {
            log.error(
                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
            );
            await downloadModel(modelPath, modelName);
        }
export const createMLWorker = (window: BrowserWindow) => {
    if (_child) {
        log.debug(() => "Terminating previous ML worker process");
        _child.kill();
        _child = undefined;
    }

    return modelPath;
};
    const { port1, port2 } = new MessageChannelMain();

/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
    path.join(app.getPath("userData"), "models", modelName);
    const child = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
    const userDataPath = app.getPath("userData");
    child.postMessage({ userDataPath }, [port1]);

const downloadModel = async (saveLocation: string, name: string) => {
    // `mkdir -p` the directory where we want to save the model.
    const saveDir = path.dirname(saveLocation);
    await fs.mkdir(saveDir, { recursive: true });
    // Download.
    log.info(`Downloading ML model from ${name}`);
    const url = `https://models.ente.io/${name}`;
    const res = await net.fetch(url);
    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
    const body = res.body;
    if (!body) throw new Error(`Received a null response for ${url}`);
    // Save.
    await writeStream(saveLocation, body);
    log.info(`Downloaded CLIP model ${name}`);
    window.webContents.postMessage("createMLWorker/port", undefined, [port2]);

    handleMessagesFromUtilityProcess(child);

    _child = child;
};

/**
 * Create an ONNX {@link InferenceSession} with some defaults.
 * Handle messages posted from the utility process.
 *
 * [Note: Using Electron APIs in UtilityProcess]
 *
 * Only a small subset of the Electron APIs are available to a UtilityProcess.
 * As of writing (Jul 2024, Electron 30), only the following are available:
 *
 * - net
 * - systemPreferences
 *
 * In particular, `app` is not available.
 *
 * We structure our code so that it doesn't need anything apart from `net`.
 *
 * For the other cases,
 *
 * - Additional parameters to the utility process are passed along with the
 *   initial message where we provide it the message port.
 *
 * - When we need to communicate from the utility process to the main process,
 *   we use the `parentPort` in the utility process.
 */
const createInferenceSession = async (modelPath: string) => {
    return await ort.InferenceSession.create(modelPath, {
        // Restrict the number of threads to 1.
        intraOpNumThreads: 1,
        // Be more conservative with RAM usage.
        enableCpuMemArena: false,
const handleMessagesFromUtilityProcess = (child: UtilityProcess) => {
    const logTag = "[ml-worker]";
    child.on("message", (m: unknown) => {
        if (m && typeof m == "object" && "method" in m && "p" in m) {
            const p = m.p;
            switch (m.method) {
                case "log.errorString":
                    if (typeof p == "string") {
                        log.error(`${logTag} ${p}`);
                        return;
                    }
                    break;
                case "log.info":
                    if (Array.isArray(p)) {
                        // Need to cast from any[] to unknown[]
                        log.info(logTag, ...(p as unknown[]));
                        return;
                    }
                    break;
                case "log.debugString":
                    if (typeof p == "string") {
                        log.debug(() => `${logTag} ${p}`);
                        return;
                    }
                    break;
                default:
                    break;
            }
        }
        log.info("Ignoring unknown message from ML worker", m);
    });
};

@@ -3,7 +3,6 @@
 */
import { net, protocol } from "electron/main";
import { randomUUID } from "node:crypto";
import { createWriteStream, existsSync } from "node:fs";
import fs from "node:fs/promises";
import { Readable } from "node:stream";
import { ReadableStream } from "node:stream/web";
@@ -12,6 +11,7 @@ import log from "./log";
import { ffmpegConvertToMP4 } from "./services/ffmpeg";
import { markClosableZip, openZip } from "./services/zip";
import { ensure } from "./utils/common";
import { writeStream } from "./utils/stream";
import {
    deleteTempFile,
    deleteTempFileIgnoringErrors,
@@ -142,6 +142,7 @@ const handleReadZip = async (zipPath: string, entryName: string) => {
    // https://github.com/antelle/node-stream-zip/blob/master/node_stream_zip.js
    const modifiedMs = entry.time;

    // @ts-expect-error [Note: Node and web stream type mismatch]
    return new Response(webReadableStream, {
        headers: {
            // We don't know the exact type, but it doesn't really matter, just
@@ -159,39 +160,6 @@ const handleWrite = async (path: string, request: Request) => {
    return new Response("", { status: 200 });
};

/**
 * Write a (web) ReadableStream to a file at the given {@link filePath}.
 *
 * The returned promise resolves when the write completes.
 *
 * @param filePath The local file system path where the file should be written.
 *
 * @param readableStream A web
 * [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream).
 */
export const writeStream = (filePath: string, readableStream: ReadableStream) =>
    writeNodeStream(filePath, Readable.fromWeb(readableStream));

const writeNodeStream = async (filePath: string, fileStream: Readable) => {
    const writeable = createWriteStream(filePath);

    fileStream.on("error", (err) => {
        writeable.destroy(err); // Close the writable stream with an error
    });

    fileStream.pipe(writeable);

    await new Promise((resolve, reject) => {
        writeable.on("finish", resolve);
        writeable.on("error", (err) => {
            if (existsSync(filePath)) {
                void fs.unlink(filePath);
            }
            reject(err);
        });
    });
};

/**
 * A map from token to file paths for convert-to-mp4 requests that we have
 * received.

desktop/src/main/utils/comlink.ts (new file, 42 lines)
@@ -0,0 +1,42 @@
import type { Endpoint } from "comlink";
import type { MessagePortMain } from "electron";

/**
 * An adaptation of the `nodeEndpoint` function from comlink suitable for use in
 * TypeScript with an Electron utility process.
 *
 * This is an adaptation of the following function from comlink:
 * https://github.com/GoogleChromeLabs/comlink/blob/main/src/node-adapter.ts
 *
 * It has been modified (somewhat hackily) to be useful with an Electron
 * MessagePortMain instead of a Node.js worker_thread. Only things that we
 * currently need have been made to work, as you can see by the abundant type
 * casts. Caveat emptor.
 */
export const messagePortMainEndpoint = (mp: MessagePortMain): Endpoint => {
    type NL = EventListenerOrEventListenerObject;
    type EL = (data: Electron.MessageEvent) => void;
    const listeners = new WeakMap<NL, EL>();
    return {
        postMessage: (message, transfer) => {
            mp.postMessage(message, transfer as unknown as MessagePortMain[]);
        },
        addEventListener: (_, eh) => {
            const l: EL = (data) =>
                "handleEvent" in eh
                    ? eh.handleEvent({ data } as MessageEvent)
                    : eh(data as unknown as MessageEvent);
            mp.on("message", (data) => {
                l(data);
            });
            listeners.set(eh, l);
        },
        removeEventListener: (_, eh) => {
            const l = listeners.get(eh);
            if (!l) return;
            mp.off("message", l);
            listeners.delete(eh);
        },
        start: mp.start.bind(mp),
    };
};

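The adapter exists presumably because `MessagePortMain` uses `on`/`off` and delivers `Electron.MessageEvent`s rather than DOM `MessageEvent`s, which is what comlink's stock adapters expect. A minimal usage sketch, mirroring how `ml-worker.ts` above wires it up (the `ping` object is a placeholder):

```ts
// Utility process side — sketch of pairing the adapter with comlink's expose.
import { expose } from "comlink";
import { messagePortMainEndpoint } from "./comlink";

process.parentPort.once("message", (e) => {
    const port = e.ports[0];
    // Expose a (placeholder) object over the MessagePortMain we received.
    if (port) expose({ ping: () => "pong" }, messagePortMainEndpoint(port));
});
```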
desktop/src/main/utils/stream.ts (new file, 39 lines)
@@ -0,0 +1,39 @@
import { createWriteStream, existsSync } from "node:fs";
import fs from "node:fs/promises";
import { Readable } from "node:stream";

/**
 * Write a (web) ReadableStream to a file at the given {@link filePath}.
 *
 * The returned promise resolves when the write completes.
 *
 * @param filePath The local file system path where the file should be written.
 *
 * @param readableStream A web
 * [ReadableStream](https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream).
 *
 */
export const writeStream = (
    filePath: string,
    readableStream: unknown /*ReadableStream*/,
    // @ts-expect-error [Note: Node and web stream type mismatch]
) => writeNodeStream(filePath, Readable.fromWeb(readableStream));

const writeNodeStream = async (filePath: string, fileStream: Readable) => {
    const writeable = createWriteStream(filePath);

    fileStream.on("error", (err) => {
        writeable.destroy(err); // Close the writable stream with an error
    });

    fileStream.pipe(writeable);

    await new Promise((resolve, reject) => {
        writeable.on("finish", resolve);
        writeable.on("error", (err) => {
            if (existsSync(filePath)) {
                void fs.unlink(filePath);
            }
            reject(err);
        });
    });
};

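A hypothetical usage sketch of this helper, mirroring how `downloadModel` in `ml-worker.ts` persists a fetched response (the URL and path are placeholders):

```ts
// Sketch only — placeholder URL and save path.
import { net } from "electron/main";
import { writeStream } from "./utils/stream";

const download = async (url: string, savePath: string) => {
    const res = await net.fetch(url);
    if (!res.ok || !res.body) throw new Error(`Failed to fetch ${url}`);
    // Stream the response body to disk instead of buffering it in memory.
    await writeStream(savePath, res.body);
};
```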
@@ -36,6 +36,28 @@
 * - [main] desktop/src/main/ipc.ts contains impl
 */

// This code runs in the (isolated) web layer. Contrary to the impression given
// by the Electron docs (as of 2024), the window object is actually available to
// the preload script, and it is necessary for legitimate uses too.
//
// > The isolated world is connected to the DOM just the same as the main world,
// > it is just the JS contexts that are separated.
// >
// > https://github.com/electron/electron/issues/27024#issuecomment-745618327
//
// Adding this reference here tells TypeScript that DOM typings (in particular,
// window) should be introduced in the ambient scope.
//
// [Note: Node and web stream type mismatch]
//
// Unfortunately, adding this reference causes the ReadableStream typings to
// break since lib.dom.d.ts adds its own incompatible definitions of
// ReadableStream to the global scope.
//
// https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/68407

/// <reference lib="dom" />

import { contextBridge, ipcRenderer, webUtils } from "electron/renderer";

// While we can't import other code, we can import types since they're just
@@ -48,6 +70,19 @@ import type {
    ZipItem,
} from "./types/ipc";

// - Infrastructure

// We need to wait until the renderer is ready before sending ports via
// postMessage, and this promise comes in handy in such cases. We create the
// promise at the top level so that it is guaranteed to be registered before the
// load event is fired.
//
// See: https://www.electronjs.org/docs/latest/tutorial/message-ports

const windowLoaded = new Promise((resolve) => {
    window.onload = resolve;
});

// - General

const appVersion = () => ipcRenderer.invoke("appVersion");
@@ -163,17 +198,15 @@ const ffmpegExec = (

// - ML

const computeCLIPImageEmbedding = (input: Float32Array) =>
    ipcRenderer.invoke("computeCLIPImageEmbedding", input);

const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
    ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);

const detectFaces = (input: Float32Array) =>
    ipcRenderer.invoke("detectFaces", input);

const computeFaceEmbeddings = (input: Float32Array) =>
    ipcRenderer.invoke("computeFaceEmbeddings", input);
const createMLWorker = () => {
    ipcRenderer.send("createMLWorker");
    ipcRenderer.on("createMLWorker/port", (event) => {
        void windowLoaded.then(() => {
            // "*" is the origin to send to.
            window.postMessage("createMLWorker/port", "*", event.ports);
        });
    });
};

// - Watch

@@ -281,8 +314,11 @@ const clearPendingUploads = () => ipcRenderer.invoke("clearPendingUploads");
 * operation when it happens across threads.
 * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
 *
 * In our case though, we're not dealing with threads but separate processes. So
 * the ArrayBuffer will be copied:
 * In our case though, we're not dealing with threads but separate processes.
 * Electron currently only supports transferring MessagePorts:
 * https://github.com/electron/electron/issues/34905
 *
 * So the ArrayBuffer will be copied:
 *
 * > "parameters, errors and return values are **copied** when they're sent over
 * > the bridge".
@@ -339,10 +375,7 @@ contextBridge.exposeInMainWorld("electron", {

    // - ML

    computeCLIPImageEmbedding,
    computeCLIPTextEmbeddingIfAvailable,
    detectFaces,
    computeFaceEmbeddings,
    createMLWorker,

    // - Watch

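The web layer's counterpart — forwarding the received port onwards to the ML web worker — is not part of this diff. A hypothetical sketch of that forwarding, per [Note: ML IPC] (the worker path and message shape are illustrative):

```ts
// Renderer (web) main thread — hypothetical sketch; not part of this diff.
const mlWorker = new Worker(new URL("services/ml/worker.ts", import.meta.url));

window.addEventListener("message", (event: MessageEvent) => {
    if (event.data == "createMLWorker/port") {
        const [port] = event.ports;
        // Relay the Electron-provided MessagePort to the web worker.
        if (port) mlWorker.postMessage({ type: "port" }, [port]);
    }
});

// Ask the main process (via the preload bridge) to spawn the utility process.
window.electron.createMLWorker();
```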
@@ -968,6 +968,11 @@ combined-stream@^1.0.8:
  dependencies:
    delayed-stream "~1.0.0"

comlink@^4.4.1:
  version "4.4.1"
  resolved "https://registry.yarnpkg.com/comlink/-/comlink-4.4.1.tgz#e568b8e86410b809e8600eb2cf40c189371ef981"
  integrity sha512-+1dlx0aY5Jo1vHy/tSsIGpSkN4tS9rZSW8FIhG0JH/crs9wwweswIo/POr451r7bZww3hFbPAKnTpimzL/mm4Q==

commander@^5.0.0:
  version "5.1.0"
  resolved "https://registry.yarnpkg.com/commander/-/commander-5.1.0.tgz#46abbd1652f8e059bddaef99bbdcb2ad9cf179ae"

@@ -82,7 +82,8 @@ export const register = async (): Promise<Registration> => {

    // Register keypair with museum to get a pairing code.
    let pairingCode: string | undefined;
    // TODO: eslint has fixed this spurious warning, but we're not on the latest
    // [TODO: spurious while(true) eslint warning].
    // eslint has fixed this spurious warning, but we're not on the latest
    // version yet, so add a disable.
    // https://github.com/eslint/eslint/pull/18286
    /* eslint-disable no-constant-condition */

@@ -82,7 +82,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
                    <EnteMenuItem
                        endIcon={<ChevronRight />}
                        onClick={() => setOpenMLSettings(true)}
                        label={pt("ML search")}
                        label={pt("Face and magic search")}
                    />
                </MenuItemGroup>
            </Box>

@@ -85,14 +85,9 @@ export default function Preferences({ open, onClose, onRootClose }) {
                        <EnteMenuItem
                            endIcon={<ChevronRight />}
                            onClick={() => setOpenMLSettings(true)}
                            label={pt("ML search")}
                            label={pt("Face and magic search")}
                        />
                    </MenuItemGroup>
                    <MenuSectionTitle
                        title={pt(
                            "Face recognition, magic search and more",
                        )}
                    />
                </Box>
            )}
        </Stack>

@@ -22,7 +22,7 @@ export const photosLogout = async () => {
    // See: [Note: Caching IDB instances in separate execution contexts].

    try {
        terminateMLWorker();
        await terminateMLWorker();
    } catch (e) {
        ignoreError("face", e);
    }

@@ -1,13 +1,12 @@
import { isDesktop } from "@/base/app";
import { ensureElectron } from "@/base/electron";
import log from "@/base/log";
import { FileType } from "@/media/file-type";
import {
    clipMatches,
    isMLEnabled,
    isMLSupported,
    mlStatusSnapshot,
} from "@/new/photos/services/ml";
import { clipMatches } from "@/new/photos/services/ml/clip";
import type { Person } from "@/new/photos/services/ml/people";
import { EnteFile } from "@/new/photos/types/file";
import * as chrono from "chrono-node";
@@ -374,7 +373,7 @@ const searchClip = async (
    searchPhrase: string,
): Promise<ClipSearchScores | undefined> => {
    if (!isMLEnabled()) return undefined;
    const matches = await clipMatches(searchPhrase, ensureElectron());
    const matches = await clipMatches(searchPhrase);
    log.debug(() => ["clip/scores", matches]);
    return matches;
};

@@ -335,57 +335,21 @@ export interface Electron {
    // - ML

    /**
     * Return a CLIP embedding of the given image.
     * Create a new ML worker, terminating the older ones (if any).
     *
     * See: [Note: Natural language search using CLIP]
     * This creates a new Node.js utility process, and sets things up so that we
     * can communicate directly with that utility process using a
     * {@link MessagePort} that gets posted using "createMLWorker/port".
     *
     * The input is an opaque float32 array representing the image. The layout
     * and exact encoding of the input is specific to our implementation and the
     * ML model (CLIP) we use.
     * At the other end of that port will be an object that conforms to the
     * {@link ElectronMLWorker} interface.
     *
     * @returns A CLIP embedding (an array of 512 floating point values).
     * For more details about the IPC flow, see: [Note: ML IPC].
     *
     * Note: For simplicity of implementation, we assume that there is at most
     * one outstanding call to {@link createMLWorker}.
     */
    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;

    /**
     * Return a CLIP embedding of the given text if we already have the model
     * downloaded and prepped. If the model is not available return `undefined`.
     *
     * This differs from the other sibling ML functions in that it doesn't wait
     * for the model download to finish. It does trigger a model download, but
     * then immediately returns `undefined`. At some future point, when the
     * model download finishes, calls to this function will start returning
     * the result we seek.
     *
     * The reason for doing it in this asymmetric way is because CLIP text
     * embeddings are used as part of deducing user-initiated search results,
     * and we don't want to block that interaction on a large network request.
     *
     * See: [Note: Natural language search using CLIP]
     *
     * @param text The string whose embedding we want to compute.
     *
     * @returns A CLIP embedding.
     */
    computeCLIPTextEmbeddingIfAvailable: (
        text: string,
    ) => Promise<Float32Array | undefined>;

    /**
     * Detect faces in the given image using YOLO.
     *
     * Both the input and output are opaque binary data whose internal structure
     * is specific to our implementation and the model (YOLO) we use.
     */
    detectFaces: (input: Float32Array) => Promise<Float32Array>;

    /**
     * Return MobileFaceNet embeddings for the given faces.
     *
     * Both the input and output are opaque binary data whose internal structure
     * is specific to our implementation and the model (MobileFaceNet) we use.
     */
    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
    createMLWorker: () => void;

    // - Watch

@@ -574,6 +538,65 @@ export interface Electron {
    clearPendingUploads: () => Promise<void>;
}

/**
 * The shape of the object exposed by the Node.js ML worker process on the
 * message port that the web layer obtains by doing {@link createMLWorker}.
 */
export interface ElectronMLWorker {
    /**
     * Return a CLIP embedding of the given image.
     *
     * See: [Note: Natural language search using CLIP]
     *
     * The input is an opaque float32 array representing the image. The layout
     * and exact encoding of the input is specific to our implementation and the
     * ML model (CLIP) we use.
     *
     * @returns A CLIP embedding (an array of 512 floating point values).
     */
    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;

    /**
     * Return a CLIP embedding of the given text if we already have the model
     * downloaded and prepped. If the model is not available return `undefined`.
     *
     * This differs from the other sibling ML functions in that it doesn't wait
     * for the model download to finish. It does trigger a model download, but
     * then immediately returns `undefined`. At some future point, when the
     * model download finishes, calls to this function will start returning
     * the result we seek.
     *
     * The reason for doing it in this asymmetric way is because CLIP text
     * embeddings are used as part of deducing user-initiated search results,
     * and we don't want to block that interaction on a large network request.
     *
     * See: [Note: Natural language search using CLIP]
     *
     * @param text The string whose embedding we want to compute.
     *
     * @returns A CLIP embedding.
     */
    computeCLIPTextEmbeddingIfAvailable: (
        text: string,
    ) => Promise<Float32Array | undefined>;

    /**
     * Detect faces in the given image using YOLO.
     *
     * Both the input and output are opaque binary data whose internal structure
     * is specific to our implementation and the model (YOLO) we use.
     */
    detectFaces: (input: Float32Array) => Promise<Float32Array>;

    /**
     * Return MobileFaceNet embeddings for the given faces.
     *
     * Both the input and output are opaque binary data whose internal structure
     * is specific to our implementation and the model (MobileFaceNet) we use.
     */
    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
}

/**
 * Errors that have special semantics on the web side.
 *

@@ -28,7 +28,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
    /** The class (T) exposed by the web worker */
    public remote: Promise<Remote<InstanceType<T>>>;
    /** The web worker */
    private worker: Worker;
    public worker: Worker;
    /** An arbitrary name associated with this ComlinkWorker for debugging. */
    private name: string;

@@ -124,7 +124,7 @@ export const MLSettings: React.FC<MLSettingsProps> = ({
        <Stack spacing={"4px"} py={"12px"}>
            <Titlebar
                onClose={onClose}
                title={pt("ML search")}
                title={pt("Face and magic search")}
                onRootClose={onRootClose}
            />
            {component}
@@ -305,7 +305,7 @@ const ManageML: React.FC<ManageMLProps> = ({
    let status: string;
    switch (phase) {
        case "indexing":
            status = pt("Indexing");
            status = pt("Running");
            break;
        case "scheduled":
            status = pt("Scheduled");
@@ -319,9 +319,9 @@ const ManageML: React.FC<ManageMLProps> = ({

    const confirmDisableML = () => {
        setDialogBoxAttributesV2({
            title: pt("Disable ML search"),
            title: pt("Disable face and magic search"),
            content: pt(
                "Do you want to disable ML search on all your devices?",
                "Do you want to disable face and magic search on all your devices?",
            ),
            close: { text: t("cancel") },
            proceed: {
@@ -356,7 +356,7 @@ const ManageML: React.FC<ManageMLProps> = ({
                justifyContent={"space-between"}
            >
                <Typography color="text.faint">
                    {pt("Status")}
                    {pt("Indexing")}
                </Typography>
                <Typography>{status}</Typography>
            </Stack>

@@ -42,7 +42,7 @@ export const MLSettingsBeta: React.FC<MLSettingsBetaProps> = ({
        <Stack spacing={"4px"} py={"12px"}>
            <Titlebar
                onClose={onClose}
                title={pt("ML search")}
                title={pt("Face and magic search")}
                onRootClose={onRootClose}
            />

@@ -1,4 +1,5 @@
import { basename } from "@/base/file";
import type { ElectronMLWorker } from "@/base/types/ipc";
import { FileType } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { ensure } from "@/utils/ensure";
@@ -7,7 +8,6 @@ import { renderableImageBlob } from "../../utils/file";
import { readStream } from "../../utils/native-stream";
import DownloadManager from "../download";
import type { UploadItem } from "../upload/types";
import type { MLWorkerElectron } from "./worker-types";

/**
 * A pair of blobs - the original, and a possibly converted "renderable" one -
@@ -103,13 +103,14 @@ export const imageBitmapAndData = async (
 * be set to the {@link UploadItem} that was uploaded. This way, we can directly
 * use the on-disk file instead of needing to download the original from remote.
 *
 * @param electron The {@link MLWorkerElectron} instance that allows us to call
 * our Node.js layer for various functionality.
 * @param electron The {@link ElectronMLWorker} instance that stands as a
 * witness that we're actually running in our desktop app (and thus can safely
 * call our Node.js layer for various functionality).
 */
export const indexableBlobs = async (
    enteFile: EnteFile,
    uploadItem: UploadItem | undefined,
    electron: MLWorkerElectron,
    electron: ElectronMLWorker,
): Promise<IndexableBlobs> =>
    uploadItem
        ? await indexableUploadItemBlobs(enteFile, uploadItem, electron)
@@ -118,7 +119,7 @@ export const indexableBlobs = async (
const indexableUploadItemBlobs = async (
    enteFile: EnteFile,
    uploadItem: UploadItem,
    electron: MLWorkerElectron,
    electron: ElectronMLWorker,
) => {
    const fileType = enteFile.metadata.fileType;
    let originalImageBlob: Blob | undefined;
@@ -149,7 +150,7 @@ const indexableUploadItemBlobs = async (
 */
const readNonVideoUploadItem = async (
    uploadItem: UploadItem,
    electron: MLWorkerElectron,
    electron: ElectronMLWorker,
): Promise<File> => {
    if (typeof uploadItem == "string" || Array.isArray(uploadItem)) {
        const { response, lastModifiedMs } = await readStream(

@@ -1,9 +1,9 @@
-import type { Electron } from "@/base/types/ipc";
+import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
 import { pixelRGBBicubic } from "./image";
 import { dotProduct, norm } from "./math";
-import type { MLWorkerElectron } from "./worker-types";
+import type { CLIPMatches } from "./worker-types";

 /**
  * The version of the CLIP indexing pipeline implemented by the current client.
@@ -98,19 +98,19 @@ export type LocalCLIPIndex = CLIPIndex & {
  * be set to the {@link UploadItem} that was uploaded. This way, we can directly
  * use the on-disk file instead of needing to download the original from remote.
  *
- * @param electron The {@link MLWorkerElectron} instance that allows us to call
+ * @param electron The {@link ElectronMLWorker} instance that allows us to call
  * our Node.js layer to run the ONNX inference.
  */
 export const indexCLIP = async (
     image: ImageBitmapAndData,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<CLIPIndex> => ({
     embedding: await computeEmbedding(image.data, electron),
 });

 const computeEmbedding = async (
     imageData: ImageData,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<number[]> => {
     const clipInput = convertToCLIPInput(imageData);
     return normalized(await electron.computeCLIPImageEmbedding(clipInput));
@@ -167,26 +167,15 @@ const normalized = (embedding: Float32Array) => {
 };

 /**
- * Use CLIP to perform a natural language search over image embeddings.
- *
- * @param searchPhrase The text entered by the user in the search box.
- *
- * @param electron The {@link Electron} instance to use to communicate with the
- * native code running in our desktop app (the embedding happens in the native
- * layer).
- *
- * It returns file (IDs) that should be shown in the search results. They're
- * returned as a map from fileIDs to the scores they got (higher is better).
- * This map will only contains entries whose score was above our minimum
- * threshold.
+ * Find the files whose CLIP embedding "matches" the given {@link searchPhrase}.
  *
  * The result can also be `undefined`, which indicates that the download for the
  * ML model is still in progress (trying again later should succeed).
  */
 export const clipMatches = async (
     searchPhrase: string,
-    electron: Electron,
-): Promise<Map<number, number> | undefined> => {
+    electron: ElectronMLWorker,
+): Promise<CLIPMatches | undefined> => {
     const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase);
     if (!t) return undefined;
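The hunk ends just after the text embedding is obtained. Based on the imports visible above (`clipIndexes` from `./db`, `dotProduct` from `./math`), the remainder of `clipMatches` presumably scores each stored image embedding against the normalized text embedding and keeps entries above some cutoff. A minimal sketch of that shape; `scoreThreshold` is an illustrative constant, not a value from this diff, and the exact `clipIndexes` / `dotProduct` signatures are assumptions:

```ts
// Sketch only: assumes clipIndexes() yields { fileID, embedding } pairs
// whose embeddings are already normalized, so the dot product with the
// normalized text embedding is the cosine similarity.
const textEmbedding = normalized(t);
const scoreThreshold = 0.2; // hypothetical cutoff, not from this diff
const matches: CLIPMatches = new Map();
for (const { fileID, embedding } of await clipIndexes()) {
    const score = dotProduct(embedding, textEmbedding);
    if (score >= scoreThreshold) matches.set(fileID, score);
}
return matches;
```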
@@ -7,6 +7,7 @@
 //
 /* eslint-disable @typescript-eslint/no-non-null-assertion */

+import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { EnteFile } from "@/new/photos/types/file";
 import { Matrix } from "ml-matrix";
 import { getSimilarityTransformation } from "similarity-transformation";
@@ -24,7 +25,6 @@ import {
     warpAffineFloat32List,
 } from "./image";
 import { clamp } from "./math";
-import type { MLWorkerElectron } from "./worker-types";

 /**
  * The version of the face indexing pipeline implemented by the current client.
@@ -236,13 +236,13 @@ export interface Box {
  *
  * @param image The file's contents.
  *
- * @param electron The {@link MLWorkerElectron} instance that allows us to call
+ * @param electron The {@link ElectronMLWorker} instance that allows us to call
  * our Node.js layer to run the ONNX inference.
  */
 export const indexFaces = async (
     enteFile: EnteFile,
     { data: imageData }: ImageBitmapAndData,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<FaceIndex> => ({
     width: imageData.width,
     height: imageData.height,
@@ -252,7 +252,7 @@ export const indexFaces = async (
 const indexFaces_ = async (
     fileID: number,
     imageData: ImageData,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<Face[]> => {
     const { width, height } = imageData;
     const imageDimensions = { width, height };
@@ -316,7 +316,7 @@ const indexFaces_ = async (
  */
 const detectFaces = async (
     imageData: ImageData,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<YOLOFaceDetection[]> => {
     const rect = ({ width, height }: Dimensions) => ({
         x: 0,
@@ -878,7 +878,7 @@ const mobileFaceNetEmbeddingSize = 192;
  */
 const computeEmbeddings = async (
     faceData: Float32Array,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ): Promise<Float32Array[]> => {
     const outputData = await electron.computeFaceEmbeddings(faceData);
@@ -6,17 +6,20 @@ import { isDesktop } from "@/base/app";
 import { blobCache } from "@/base/blob-cache";
 import { ensureElectron } from "@/base/electron";
 import log from "@/base/log";
+import type { Electron } from "@/base/types/ipc";
 import { ComlinkWorker } from "@/base/worker/comlink-worker";
 import { FileType } from "@/media/file-type";
 import type { EnteFile } from "@/new/photos/types/file";
+import { ensure } from "@/utils/ensure";
 import { throttled } from "@/utils/promise";
-import { proxy } from "comlink";
+import { proxy, transfer } from "comlink";
 import { isInternalUser } from "../feature-flags";
 import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
 import type { UploadItem } from "../upload/types";
 import { regenerateFaceCrops } from "./crop";
 import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
 import { MLWorker } from "./worker";
+import type { CLIPMatches } from "./worker-types";

 /**
  * In-memory flag that tracks if ML is enabled.
@@ -33,7 +36,7 @@ import { MLWorker } from "./worker";
 let _isMLEnabled = false;

 /** Cached instance of the {@link ComlinkWorker} that wraps our web worker. */
-let _comlinkWorker: ComlinkWorker<typeof MLWorker> | undefined;
+let _comlinkWorker: Promise<ComlinkWorker<typeof MLWorker>> | undefined;

 /**
  * Subscriptions to {@link MLStatus}.
@@ -50,29 +53,28 @@ let _mlStatusListeners: (() => void)[] = [];
 let _mlStatusSnapshot: MLStatus | undefined;

 /** Lazily created, cached, instance of {@link MLWorker}. */
-const worker = async () => {
-    if (!_comlinkWorker) _comlinkWorker = await createComlinkWorker();
-    return _comlinkWorker.remote;
-};
+const worker = () =>
+    (_comlinkWorker ??= createComlinkWorker()).then((cw) => cw.remote);

 const createComlinkWorker = async () => {
     const electron = ensureElectron();
-    const mlWorkerElectron = {
-        detectFaces: electron.detectFaces,
-        computeFaceEmbeddings: electron.computeFaceEmbeddings,
-        computeCLIPImageEmbedding: electron.computeCLIPImageEmbedding,
-    };
     const delegate = {
         workerDidProcessFile,
     };

+    // Obtain a message port from the Electron layer.
+    const messagePort = await createMLWorker(electron);
+
     const cw = new ComlinkWorker<typeof MLWorker>(
         "ML",
         new Worker(new URL("worker.ts", import.meta.url)),
     );

     await cw.remote.then((w) =>
-        w.init(proxy(mlWorkerElectron), proxy(delegate)),
+        // Forward the port to the web worker.
+        w.init(transfer(messagePort, [messagePort]), proxy(delegate)),
     );

     return cw;
 };
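One detail worth calling out in the rewritten `worker()` above: the module-level cache now holds the *promise* of the worker rather than the worker itself. A distilled illustration of why that matters (names here are illustrative, not from the diff):

```ts
// Caching the in-flight promise means two concurrent callers share a
// single createExpensive() call. Caching only the resolved value (as the
// old `if (!cached) cached = await create()` shape effectively did) lets a
// second caller slip in while the first await is still pending, creating a
// duplicate worker.
let cached: Promise<string> | undefined;
const createExpensive = async () => "resource";
const get = (): Promise<string> => (cached ??= createExpensive());
```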
@@ -85,13 +87,40 @@ const createComlinkWorker = async () => {
  *
  * It is also called when the user pauses or disables ML.
  */
-export const terminateMLWorker = () => {
+export const terminateMLWorker = async () => {
     if (_comlinkWorker) {
-        _comlinkWorker.terminate();
+        await _comlinkWorker.then((cw) => cw.terminate());
         _comlinkWorker = undefined;
     }
 };

+/**
+ * Obtain a port from the Node.js layer that can be used to communicate with the
+ * ML worker process.
+ */
+const createMLWorker = (electron: Electron): Promise<MessagePort> => {
+    // The main process will do its thing, and send back the port it created to
+    // us by sending a message on the "createMLWorker/port" channel via the
+    // postMessage API. This roundabout way is needed because MessagePorts
+    // cannot be transferred via the usual send/invoke pattern.
+
+    const port = new Promise<MessagePort>((resolve) => {
+        const l = ({ source, data, ports }: MessageEvent) => {
+            // The source check verifies that the message is coming from our own
+            // preload script. The data is the message that was posted.
+            if (source == window && data == "createMLWorker/port") {
+                window.removeEventListener("message", l);
+                resolve(ensure(ports[0]));
+            }
+        };
+        window.addEventListener("message", l);
+    });
+
+    electron.createMLWorker();
+
+    return port;
+};
+
 /**
  * Return true if the current client supports ML.
  *
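The renderer half above only waits for the port to arrive; the main-process half is visible in this commit only as the `attachMainWindowIPCHandlers` / `createMLWorker(mainWindow)` wiring. A plausible sketch of that other half, under the assumption that it forks an Electron utility process and pipes one end of a `MessageChannelMain` to it (the file path, bundle name, and exact shape are assumptions, not shown in this diff):

```ts
// electron/src/main/services/ml.ts (sketch; names are illustrative)
import { MessageChannelMain, utilityProcess, type BrowserWindow } from "electron";
import path from "node:path";

export const createMLWorker = (mainWindow: BrowserWindow) => {
    // One end goes to the utility process, the other to the renderer.
    const { port1, port2 } = new MessageChannelMain();

    // Fork the Node.js utility process that runs the ONNX inference.
    const worker = utilityProcess.fork(path.join(__dirname, "ml-worker.js"));
    worker.postMessage("port", [port1]);

    // Hand the renderer its end, on the channel the preload script relays.
    mainWindow.webContents.postMessage("createMLWorker/port", null, [port2]);
};
```

The `source == window` check in the renderer then pairs with a small relay in the preload script, which would look roughly like this (again a sketch, not part of this diff):

```ts
// electron/src/preload.ts (sketch): forward the port from the main process
// to the page via window.postMessage, which is the only way a preload
// script can hand a MessagePort to page code.
import { ipcRenderer } from "electron";

ipcRenderer.on("createMLWorker/port", (event) => {
    window.postMessage("createMLWorker/port", "*", event.ports);
});
```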
@@ -163,7 +192,7 @@ export const disableML = async () => {
     await updateIsMLEnabledRemote(false);
     setIsMLEnabledLocal(false);
     _isMLEnabled = false;
-    terminateMLWorker();
+    await terminateMLWorker();
     triggerStatusUpdate();
 };
@@ -369,6 +398,22 @@ const setInterimScheduledStatus = () => {

 const workerDidProcessFile = throttled(updateMLStatusSnapshot, 2000);

+/**
+ * Use CLIP to perform a natural language search over image embeddings.
+ *
+ * @param searchPhrase The text entered by the user in the search box.
+ *
+ * It returns file (IDs) that should be shown in the search results, along with
+ * their scores.
+ *
+ * The result can also be `undefined`, which indicates that the download for the
+ * ML model is still in progress (trying again later should succeed).
+ */
+export const clipMatches = (
+    searchPhrase: string,
+): Promise<CLIPMatches | undefined> =>
+    worker().then((w) => w.clipMatches(searchPhrase));
+
 /**
  * Return the IDs of all the faces in the given {@link enteFile} that are not
  * associated with a person cluster.
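This re-export becomes the public entry point for CLIP search: the main thread no longer invokes IPC for text embeddings but asks the (already running) web worker, which in turn calls the utility process over its port. A hypothetical caller in the search layer (not part of this diff) might use it like so:

```ts
// Hypothetical consumer of clipMatches; the search phrase is illustrative.
const matches = await clipMatches("sunset at the beach");
if (!matches) {
    // Model download still in progress; retry later.
} else {
    // Sort matching file IDs by descending score for display.
    const fileIDs = [...matches.entries()]
        .sort(([, a], [, b]) => b - a)
        .map(([id]) => id);
}
```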
@@ -1,22 +1,7 @@
 /**
- * @file Type for the objects shared (as a Comlink proxy) by the main thread and
- * the ML worker.
+ * @file Types for the objects shared between the main thread and the ML worker.
  */

-/**
- * A subset of {@link Electron} provided to the {@link MLWorker}.
- *
- * `globalThis.electron` does not exist in the execution context of web workers.
- * So instead, we manually provide a proxy object of type
- * {@link MLWorkerElectron} that exposes a subset of the functions from
- * {@link Electron} that are needed by the code running in the ML web worker.
- */
-export interface MLWorkerElectron {
-    detectFaces: (input: Float32Array) => Promise<Float32Array>;
-    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
-    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
-}

 /**
  * Callbacks invoked by the worker at various points in the indexing pipeline to
  * notify the main thread of events it might be interested in.
@@ -25,7 +10,18 @@ export interface MLWorkerDelegate {
     /**
      * Called whenever a file is processed during indexing.
      *
-     * It is called both when the indexing was successful or failed.
+     * It is called both when the indexing was successful and when it failed.
      */
     workerDidProcessFile: () => void;
 }

+/**
+ * The file IDs that should be considered as matches for a particular search
+ * phrase, each with their associated score.
+ *
+ * This is a map from the matching fileIDs to the scores they got (higher is
+ * better). It will only contain entries whose score was above our minimum
+ * threshold.
+ */
+export type CLIPMatches = Map<number, number>;
@@ -3,14 +3,15 @@ import { isHTTP4xxError } from "@/base/http";
 import { getKVN } from "@/base/kv";
 import { ensureAuthToken } from "@/base/local-user";
 import log from "@/base/log";
+import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { EnteFile } from "@/new/photos/types/file";
 import { fileLogID } from "@/new/photos/utils/file";
 import { ensure } from "@/utils/ensure";
 import { wait } from "@/utils/promise";
 import { DOMParser } from "@xmldom/xmldom";
-import { expose } from "comlink";
+import { expose, wrap } from "comlink";
 import downloadManager from "../download";
-import { cmpNewLib2, extractRawExif } from "../exif";
+import { cmpNewLib2, extractRawExif, type RawExifTags } from "../exif";
 import { getAllLocalFiles, getLocalTrashedFiles } from "../files";
 import type { UploadItem } from "../upload/types";
 import {
@@ -18,7 +19,12 @@ import {
     indexableBlobs,
     type ImageBitmapAndData,
 } from "./blob";
-import { clipIndexingVersion, indexCLIP, type CLIPIndex } from "./clip";
+import {
+    clipIndexingVersion,
+    clipMatches,
+    indexCLIP,
+    type CLIPIndex,
+} from "./clip";
 import { saveFaceCrops } from "./crop";
 import {
     indexableFileIDs,
@@ -29,10 +35,11 @@ import {
 import {
     fetchDerivedData,
     putDerivedData,
+    type RawRemoteDerivedData,
     type RemoteDerivedData,
 } from "./embedding";
 import { faceIndexingVersion, indexFaces, type FaceIndex } from "./face";
-import type { MLWorkerDelegate, MLWorkerElectron } from "./worker-types";
+import type { CLIPMatches, MLWorkerDelegate } from "./worker-types";

 const idleDurationStart = 5; /* 5 seconds */
 const idleDurationMax = 16 * 60; /* 16 minutes */
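The two constants above bound the worker's idle polling. Presumably the idle duration starts at 5 seconds and backs off toward 16 minutes when there is nothing to index; the doubling rule below is an inference from the constants, not something this hunk shows:

```ts
// Assumed exponential backoff schedule for the idle tick (a sketch, not
// the implementation): 5s, 10s, 20s, ... capped at 16 minutes, and reset
// whenever there is work again.
const idleDurationStart = 5; /* seconds */
const idleDurationMax = 16 * 60;

let idleDuration = idleDurationStart;
const nextIdleDuration = () =>
    (idleDuration = Math.min(idleDuration * 2, idleDurationMax));
const resetIdleDuration = () => (idleDuration = idleDurationStart);
```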
@@ -67,9 +74,12 @@ interface IndexableItem {
  * - "backfillq": fetching remote embeddings of unindexed items, and then
  *   indexing them if needed,
  * - "idle": in between state transitions.
+ *
+ * In addition, MLWorker can also be invoked for interactive tasks: in
+ * particular, for finding the closest CLIP match when the user does a search.
  */
 export class MLWorker {
-    private electron: MLWorkerElectron | undefined;
+    private electron: ElectronMLWorker | undefined;
     private delegate: MLWorkerDelegate | undefined;
     private state: "idle" | "indexing" = "idle";
     private liveQ: IndexableItem[] = [];
@@ -82,15 +92,16 @@ export class MLWorker {
      * This is conceptually the constructor, however it is easier to have this
      * as a separate function to avoid complicating the comlink types further.
      *
-     * @param electron The {@link MLWorkerElectron} that allows the worker to
-     * use the functionality provided by our Node.js layer when running in the
-     * context of our desktop app.
+     * @param port A {@link MessagePort} that allows us to communicate with an
+     * Electron utility process running in the Node.js layer of our desktop app,
+     * exposing an object that conforms to the {@link ElectronMLWorker}
+     * interface.
      *
      * @param delegate The {@link MLWorkerDelegate} the worker can use to inform
      * the main thread of interesting events.
      */
-    async init(electron: MLWorkerElectron, delegate?: MLWorkerDelegate) {
-        this.electron = electron;
+    async init(port: MessagePort, delegate: MLWorkerDelegate) {
+        this.electron = wrap<ElectronMLWorker>(port);
         this.delegate = delegate;
         // Initialize the downloadManager running in the web worker with the
         // user's token. It'll be used to download files to index if needed.
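The `wrap()` call above turns the transferred port into a typed async proxy: every `this.electron.…` call becomes a postMessage round-trip to the utility process. The far end of that port is not part of this web-side diff; a sketch of what it plausibly looks like follows. Note the assumptions: the module paths are invented, and since comlink's `expose()` expects a DOM-style endpoint while Electron utility processes receive `MessagePortMain` objects, a minimal adapter is included here.

```ts
// Sketch of the utility-process end of the port (illustrative, not from
// this diff).
import { expose, type Endpoint } from "comlink";
import type { MessagePortMain } from "electron";
import { computeCLIPImageEmbedding } from "./ml-clip"; // assumed path
import { computeFaceEmbeddings, detectFaces } from "./ml-face"; // assumed path

// Adapt Electron's MessagePortMain to the endpoint shape comlink expects.
const toEndpoint = (port: MessagePortMain): Endpoint => ({
    // Transfer lists are elided in this sketch; results travel by clone.
    postMessage: (message) => port.postMessage(message),
    addEventListener: (_type, listener) =>
        port.on("message", (e) => {
            const ev = { data: e.data } as unknown as MessageEvent;
            if (typeof listener == "function") listener(ev);
            else listener.handleEvent(ev);
        }),
    removeEventListener: () => {
        /* one-shot wiring; nothing to detach */
    },
    start: () => port.start(),
});

process.parentPort.once("message", (e) => {
    const port = e.ports[0]!; // the MessagePortMain forwarded by main
    expose(
        { computeCLIPImageEmbedding, detectFaces, computeFaceEmbeddings },
        toEndpoint(port),
    );
    port.start();
});
```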
@@ -176,6 +187,13 @@ export class MLWorker {
         return this.state == "indexing";
     }

+    /**
+     * Find {@link CLIPMatches} for a given {@link searchPhrase}.
+     */
+    async clipMatches(searchPhrase: string): Promise<CLIPMatches | undefined> {
+        return clipMatches(searchPhrase, ensure(this.electron));
+    }
+
     private async tick() {
         log.debug(() => [
             "ml/tick",
@@ -224,7 +242,7 @@ export class MLWorker {
     }

     /** Return the next batch of items to backfill (if any). */
-    async backfillQ() {
+    private async backfillQ() {
         const userID = ensure(await getKVN("userID"));
         // Find files that our local DB thinks need syncing.
         const filesByID = await syncWithLocalFilesAndGetFilesToIndex(
@@ -256,7 +274,7 @@ expose(MLWorker);
  */
 const indexNextBatch = async (
     items: IndexableItem[],
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
     delegate: MLWorkerDelegate | undefined,
 ) => {
     // Don't try to index if we wouldn't be able to upload them anyway. The
@@ -270,18 +288,37 @@ const indexNextBatch = async (
     // Nothing to do.
     if (items.length == 0) return false;

-    // Index, keeping track if any of the items failed.
+    // Keep track if any of the items failed.
     let allSuccess = true;
-    for (const item of items) {
-        try {
-            await index(item, electron);
-            delegate?.workerDidProcessFile();
-            // Possibly unnecessary, but let us drain the microtask queue.
-            await wait(0);
-        } catch (e) {
-            log.warn(`Skipping unindexable file ${item.enteFile.id}`, e);
-            allSuccess = false;
+
+    // Index up to 4 items simultaneously.
+    const tasks = new Array<Promise<void> | undefined>(4).fill(undefined);
+
+    let i = 0;
+    while (i < items.length) {
+        for (let j = 0; j < tasks.length; j++) {
+            if (i < items.length && !tasks[j]) {
+                tasks[j] = index(ensure(items[i++]), electron)
+                    .then(() => {
+                        tasks[j] = undefined;
+                    })
+                    .catch(() => {
+                        allSuccess = false;
+                        tasks[j] = undefined;
+                    });
+            }
+        }
+
+        // Wait for at least one to complete (the other runners continue running
+        // even if one promise reaches the finish line).
+        await Promise.race(tasks);
+
+        // Let the main thread know we're doing something.
+        delegate?.workerDidProcessFile();
+
+        // Let us drain the microtask queue. This also gives a chance for other
+        // interactive tasks like `clipMatches` to run.
+        await wait(0);
     }

     // Return true if nothing failed.
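This slot-based loop is the heart of the commit's concurrency change: up to four `index()` calls are in flight at once, `Promise.race` wakes the loop whenever a slot frees up, and the `await wait(0)` yield keeps the worker responsive to interactive `clipMatches` requests. The same idea can be written as a generic, self-contained helper; this is a sketch of the pattern, not what the diff uses:

```ts
// Generic bounded-concurrency runner (illustrative; not from this diff).
// Runs fn over items with at most `limit` calls in flight, and reports
// whether every call succeeded.
const runWithLimit = async <T>(
    items: readonly T[],
    limit: number,
    fn: (item: T) => Promise<void>,
): Promise<boolean> => {
    let ok = true;
    let next = 0;
    // Each runner pulls the next unclaimed item until the list is drained.
    const runner = async () => {
        while (next < items.length) {
            const item = items[next++]!;
            try {
                await fn(item);
            } catch {
                ok = false;
            }
        }
    };
    await Promise.all(Array.from({ length: limit }, () => runner()));
    return ok;
};
```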
@@ -374,7 +411,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
  */
 const index = async (
     { enteFile, uploadItem, remoteDerivedData }: IndexableItem,
-    electron: MLWorkerElectron,
+    electron: ElectronMLWorker,
 ) => {
     const f = fileLogID(enteFile);
     const fileID = enteFile.id;
@@ -463,10 +500,7 @@ const index = async (
         [faceIndex, clipIndex, exif] = await Promise.all([
             existingFaceIndex ?? indexFaces(enteFile, image, electron),
             existingCLIPIndex ?? indexCLIP(image, electron),
-            existingExif ??
-                (originalImageBlob
-                    ? extractRawExif(originalImageBlob)
-                    : undefined),
+            existingExif ?? tryExtractExif(originalImageBlob, f),
         ]);
     } catch (e) {
         // See: [Note: Transient and permanent indexing failures]
@@ -475,8 +509,12 @@ const index = async (
         throw e;
     }

-    if (originalImageBlob && exif)
-        await cmpNewLib2(enteFile, originalImageBlob, exif);
+    try {
+        if (originalImageBlob && exif)
+            await cmpNewLib2(enteFile, originalImageBlob, exif);
+    } catch (e) {
+        log.warn(`Skipping exif cmp for ${f}`, e);
+    }

     log.debug(() => {
         const ms = Date.now() - startTime;
@@ -504,22 +542,27 @@ const index = async (
     // parts. See: [Note: Preserve unknown derived data fields].

     const existingRawDerivedData = remoteDerivedData?.raw ?? {};
-    const rawDerivedData = {
+    const rawDerivedData: RawRemoteDerivedData = {
         ...existingRawDerivedData,
         face: remoteFaceIndex,
         clip: remoteCLIPIndex,
-        exif,
+        ...(exif ? { exif } : {}),
     };

-    log.debug(() => ["Uploading derived data", rawDerivedData]);
+    if (existingFaceIndex && existingCLIPIndex && !exif) {
+        // If we were indexing just for exif, but exif generation didn't
+        // happen, there is no need to upload.
+    } else {
+        log.debug(() => ["Uploading derived data", rawDerivedData]);

-    try {
-        await putDerivedData(enteFile, rawDerivedData);
-    } catch (e) {
-        // See: [Note: Transient and permanent indexing failures]
-        log.error(`Failed to put derived data for ${f}`, e);
-        if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
-        throw e;
+        try {
+            await putDerivedData(enteFile, rawDerivedData);
+        } catch (e) {
+            // See: [Note: Transient and permanent indexing failures]
+            log.error(`Failed to put derived data for ${f}`, e);
+            if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
+            throw e;
+        }
     }

     try {
|
||||
image.bitmap.close();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A helper function that tries to extract the raw Exif, but returns `undefined`
|
||||
* if something goes wrong (or it isn't possible) instead of throwing.
|
||||
*
|
||||
* Exif extraction is not a critical item, we don't want the actual indexing to
|
||||
* fail because we were unable to extract Exif. This is not rare: one scenario
|
||||
* is if we were trying to index a file in an exotic format. The ML indexing
|
||||
* will succeed (because we convert it to a renderable blob), but the Exif
|
||||
* extraction will fail (since it needs the original blob, but the original blob
|
||||
* can be an arbitrary format).
|
||||
*
|
||||
* @param originalImageBlob A {@link Blob} containing the original data for the
|
||||
* image (or the image component of a live photo) whose Exif we're trying to
|
||||
* extract. If this is not available, we skip the extraction and return
|
||||
* `undefined`.
|
||||
*
|
||||
* @param f The {@link fileLogID} for the file this blob corresponds to.
|
||||
*/
|
||||
export const tryExtractExif = async (
|
||||
originalImageBlob: Blob | undefined,
|
||||
f: string,
|
||||
): Promise<RawExifTags | undefined> => {
|
||||
if (!originalImageBlob) return undefined;
|
||||
try {
|
||||
return await extractRawExif(originalImageBlob);
|
||||
} catch (e) {
|
||||
log.warn(`Ignoring error during Exif extraction for ${f}`, e);
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6,8 +6,7 @@
  * See: [Note: IPC streams].
  */

-import type { Electron, ZipItem } from "@/base/types/ipc";
-import type { MLWorkerElectron } from "../services/ml/worker-types";
+import type { Electron, ElectronMLWorker, ZipItem } from "@/base/types/ipc";

 /**
  * Stream the given file or zip entry from the user's local file system.
@@ -18,7 +17,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
  *
  * To avoid accidentally invoking it in a non-desktop app context, it requires
  * the {@link Electron} (or a functionally similar) object as a parameter (even
- * though it doesn't use it).
+ * though it doesn't need or use it).
  *
  * @param pathOrZipItem Either the path on the file on the user's local file
  * system whose contents we want to stream. Or a tuple containing the path to a
@@ -36,7 +35,7 @@ import type { MLWorkerElectron } from "../services/ml/worker-types";
  * reading, expressed as epoch milliseconds.
  */
 export const readStream = async (
-    _: Electron | MLWorkerElectron,
+    _: Electron | ElectronMLWorker,
     pathOrZipItem: string | ZipItem,
 ): Promise<{ response: Response; size: number; lastModifiedMs: number }> => {
     let url: URL;