From d0b1ff5520ea122a3f096b94e0fc7123ab395082 Mon Sep 17 00:00:00 2001 From: Manav Rathi Date: Sat, 20 Apr 2024 10:10:33 +0530 Subject: [PATCH] In band signalling --- desktop/src/main/ipc.ts | 6 +- desktop/src/main/services/ml-clip.ts | 93 ++++------------- desktop/src/main/services/ml-face.ts | 10 +- desktop/src/main/services/ml.ts | 8 +- desktop/src/preload.ts | 99 +++++++++++-------- desktop/src/types/ipc.ts | 15 --- web/apps/photos/src/services/clip-service.ts | 4 +- web/apps/photos/src/services/searchService.ts | 40 ++++---- web/packages/next/types/ipc.ts | 17 +++- web/packages/shared/error/index.ts | 2 - 10 files changed, 129 insertions(+), 165 deletions(-) diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index a421e7642a..b8763c520f 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -45,7 +45,7 @@ import { convertToJPEG, generateImageThumbnail, } from "./services/imageProcessor"; -import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip"; +import { clipImageEmbedding, clipTextEmbeddingIfAvailable } from "./services/ml-clip"; import { detectFaces, faceEmbedding } from "./services/ml-face"; import { clearStores, @@ -169,8 +169,8 @@ export const attachIPCHandlers = () => { clipImageEmbedding(jpegImageData), ); - ipcMain.handle("clipTextEmbedding", (_, text: string) => - clipTextEmbedding(text), + ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) => + clipTextEmbeddingIfAvailable(text), ); ipcMain.handle("detectFaces", (_, input: Float32Array) => diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts index 954e3859fe..b4fa2c66d1 100644 --- a/desktop/src/main/services/ml-clip.ts +++ b/desktop/src/main/services/ml-clip.ts @@ -5,86 +5,21 @@ * * @see `web/apps/photos/src/services/clip-service.ts` for more details. 
*/ -import { existsSync } from "fs"; import jpeg from "jpeg-js"; import fs from "node:fs/promises"; import * as ort from "onnxruntime-node"; import Tokenizer from "../../thirdparty/clip-bpe-ts/mod"; -import { CustomErrors } from "../../types/ipc"; import log from "../log"; import { writeStream } from "../stream"; import { generateTempFilePath } from "../temp"; import { deleteTempFile } from "./ffmpeg"; -import { - createInferenceSession, - downloadModel, - makeCachedInferenceSession, - modelSavePath, -} from "./ml"; +import { makeCachedInferenceSession } from "./ml"; const cachedCLIPImageSession = makeCachedInferenceSession( "clip-image-vit-32-float32.onnx", 351468764 /* 335.2 MB */, ); -const cachedCLIPTextSession = makeCachedInferenceSession( - "clip-text-vit-32-uint8.onnx", - 64173509 /* 61.2 MB */, -); - -let textModelDownloadInProgress = false; - -/* TODO(MR): use the generic method. Then we can remove the exports for the - internal details functions that we use here */ -const textModelPathDownloadingIfNeeded = async () => { - if (textModelDownloadInProgress) - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - - const modelPath = modelSavePath(textModelName); - if (!existsSync(modelPath)) { - log.info("CLIP text model not found, downloading"); - textModelDownloadInProgress = true; - downloadModel(modelPath, textModelName) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { - textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } else { - const localFileSize = (await fs.stat(modelPath)).size; - if (localFileSize !== textModelByteSize) { - log.error( - `CLIP text model size ${localFileSize} does not match the expected size, downloading again`, - ); - textModelDownloadInProgress = true; - downloadModel(modelPath, textModelName) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { 
- textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } - } - - return modelPath; -}; - -let _textSession: any = null; - -const onnxTextSession = async () => { - if (!_textSession) { - const modelPath = await textModelPathDownloadingIfNeeded(); - _textSession = await createInferenceSession(modelPath); - } - return _textSession; -}; - export const clipImageEmbedding = async (jpegImageData: Uint8Array) => { const tempFilePath = await generateTempFilePath(""); const imageStream = new Response(jpegImageData.buffer).body; @@ -195,6 +130,11 @@ const normalizeEmbedding = (embedding: Float32Array) => { return embedding; }; +const cachedCLIPTextSession = makeCachedInferenceSession( + "clip-text-vit-32-uint8.onnx", + 64173509 /* 61.2 MB */, +); + let _tokenizer: Tokenizer = null; const getTokenizer = () => { if (!_tokenizer) { @@ -203,14 +143,21 @@ const getTokenizer = () => { return _tokenizer; }; -export const clipTextEmbedding = async (text: string) => { - const session = await Promise.race([ +export const clipTextEmbeddingIfAvailable = async (text: string) => { + const sessionOrStatus = await Promise.race([ cachedCLIPTextSession(), - new Promise<"downloading-model">((resolve) => - setTimeout(() => resolve("downloading-model"), 100), - ), + "downloading-model", ]); - await onnxTextSession(); + + // Don't wait for the download to complete + if (typeof sessionOrStatus == "string") { + console.log( + "Ignoring CLIP text embedding request because model download is pending", + ); + return undefined; + } + + const session = sessionOrStatus; const t1 = Date.now(); const tokenizer = getTokenizer(); const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); @@ -223,6 +170,6 @@ export const clipTextEmbedding = async (text: string) => { () => `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, ); - const textEmbedding = results["output"].data; + const textEmbedding = 
results["output"].data as Float32Array; return normalizeEmbedding(textEmbedding); }; diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts index 9b16a0798c..2309d193cd 100644 --- a/desktop/src/main/services/ml-face.ts +++ b/desktop/src/main/services/ml-face.ts @@ -15,11 +15,6 @@ const cachedFaceDetectionSession = makeCachedInferenceSession( 30762872 /* 29.3 MB */, ); -const cachedFaceEmbeddingSession = makeCachedInferenceSession( - "mobilefacenet_opset15.onnx", - 5286998 /* 5 MB */, -); - export const detectFaces = async (input: Float32Array) => { const session = await cachedFaceDetectionSession(); const t = Date.now(); @@ -31,6 +26,11 @@ export const detectFaces = async (input: Float32Array) => { return results["output"].data; }; +const cachedFaceEmbeddingSession = makeCachedInferenceSession( + "mobilefacenet_opset15.onnx", + 5286998 /* 5 MB */, +); + export const faceEmbedding = async (input: Float32Array) => { // Dimension of each face (alias) const mobileFaceNetFaceSize = 112; diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts index 5efb9c5604..8292596a22 100644 --- a/desktop/src/main/services/ml.ts +++ b/desktop/src/main/services/ml.ts @@ -1,5 +1,5 @@ /** - * @file AI/ML related functionality. + * @file AI/ML related functionality, generic layer. * * @see also `ml-clip.ts`, `ml-face.ts`. * @@ -92,10 +92,10 @@ const modelPathDownloadingIfNeeded = async ( }; /** Return the path where the given {@link modelName} is meant to be saved */ -export const modelSavePath = (modelName: string) => +const modelSavePath = (modelName: string) => path.join(app.getPath("userData"), "models", modelName); -export const downloadModel = async (saveLocation: string, name: string) => { +const downloadModel = async (saveLocation: string, name: string) => { // `mkdir -p` the directory where we want to save the model. 
const saveDir = path.dirname(saveLocation); await fs.mkdir(saveDir, { recursive: true }); @@ -112,7 +112,7 @@ export const downloadModel = async (saveLocation: string, name: string) => { /** * Crete an ONNX {@link InferenceSession} with some defaults. */ -export const createInferenceSession = async (modelPath: string) => { +const createInferenceSession = async (modelPath: string) => { return await ort.InferenceSession.create(modelPath, { // Restrict the number of threads to 1 intraOpNumThreads: 1, diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index 7d0df41d57..a9909f6267 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -163,8 +163,10 @@ const runFFmpegCmd = ( const clipImageEmbedding = (jpegImageData: Uint8Array): Promise => ipcRenderer.invoke("clipImageEmbedding", jpegImageData); -const clipTextEmbedding = (text: string): Promise => - ipcRenderer.invoke("clipTextEmbedding", text); +const clipTextEmbeddingIfAvailable = ( + text: string, +): Promise => + ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text); const detectFaces = (input: Float32Array): Promise => ipcRenderer.invoke("detectFaces", input); @@ -263,42 +265,61 @@ const getElectronFilesFromGoogleZip = ( const getDirFiles = (dirPath: string): Promise => ipcRenderer.invoke("getDirFiles", dirPath); -// -// These objects exposed here will become available to the JS code in our -// renderer (the web/ code) as `window.ElectronAPIs.*` -// -// There are a few related concepts at play here, and it might be worthwhile to -// read their (excellent) documentation to get an understanding; -//` -// - ContextIsolation: -// https://www.electronjs.org/docs/latest/tutorial/context-isolation -// -// - IPC https://www.electronjs.org/docs/latest/tutorial/ipc -// -// [Note: Transferring large amount of data over IPC] -// -// Electron's IPC implementation uses the HTML standard Structured Clone -// Algorithm to serialize objects passed between processes. 
-// https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization -// -// In particular, ArrayBuffer is eligible for structured cloning. -// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm -// -// Also, ArrayBuffer is "transferable", which means it is a zero-copy operation -// operation when it happens across threads. -// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects -// -// In our case though, we're not dealing with threads but separate processes. So -// the ArrayBuffer will be copied: -// > "parameters, errors and return values are **copied** when they're sent over -// the bridge". -// https://www.electronjs.org/docs/latest/api/context-bridge#methods -// -// The copy itself is relatively fast, but the problem with transfering large -// amounts of data is potentially running out of memory during the copy. -// -// For an alternative, see [Note: IPC streams]. -// +/** + * These objects exposed here will become available to the JS code in our + * renderer (the web/ code) as `window.ElectronAPIs.*` + * + * There are a few related concepts at play here, and it might be worthwhile to + * read their (excellent) documentation to get an understanding; + *` + * - ContextIsolation: + * https://www.electronjs.org/docs/latest/tutorial/context-isolation + * + * - IPC https://www.electronjs.org/docs/latest/tutorial/ipc + * + * --- + * + * [Note: Custom errors across Electron/Renderer boundary] + * + * If we need to identify errors thrown by the main process when invoked from + * the renderer process, we can only use the `message` field because: + * + * > Errors thrown throw `handle` in the main process are not transparent as + * > they are serialized and only the `message` property from the original error + * > is provided to the renderer process. 
+ * > + * > - https://www.electronjs.org/docs/latest/tutorial/ipc + * > + * > Ref: https://github.com/electron/electron/issues/24427 + * + * --- + * + * [Note: Transferring large amount of data over IPC] + * + * Electron's IPC implementation uses the HTML standard Structured Clone + * Algorithm to serialize objects passed between processes. + * https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization + * + * In particular, ArrayBuffer is eligible for structured cloning. + * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm + * + * Also, ArrayBuffer is "transferable", which means it is a zero-copy + * operation when it happens across threads. + * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects + * + * In our case though, we're not dealing with threads but separate processes. So + * the ArrayBuffer will be copied: + * + * > "parameters, errors and return values are **copied** when they're sent over + * > the bridge". + * > + * > https://www.electronjs.org/docs/latest/api/context-bridge#methods + * + * The copy itself is relatively fast, but the problem with transferring large + * amounts of data is potentially running out of memory during the copy. + * + * For an alternative, see [Note: IPC streams]. + */ contextBridge.exposeInMainWorld("electron", { // - General @@ -340,7 +361,7 @@ contextBridge.exposeInMainWorld("electron", { // - ML clipImageEmbedding, - clipTextEmbedding, + clipTextEmbeddingIfAvailable, detectFaces, faceEmbedding, diff --git a/desktop/src/types/ipc.ts b/desktop/src/types/ipc.ts index d986bc4998..7bb2f1fab8 100644 --- a/desktop/src/types/ipc.ts +++ b/desktop/src/types/ipc.ts @@ -33,25 +33,10 @@ export interface PendingUploads { /** * Errors that have special semantics on the web side. 
- * - * [Note: Custom errors across Electron/Renderer boundary] - * - * We need to use the `message` field to disambiguate between errors thrown by - * the main process when invoked from the renderer process. This is because: - * - * > Errors thrown throw `handle` in the main process are not transparent as - * > they are serialized and only the `message` property from the original error - * > is provided to the renderer process. - * > - * > - https://www.electronjs.org/docs/latest/tutorial/ipc - * > - * > Ref: https://github.com/electron/electron/issues/24427 */ export const CustomErrors = { WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED: "Windows native image processing is not supported", - MODEL_DOWNLOAD_PENDING: - "Model download pending, skipping clip search request", }; /** diff --git a/web/apps/photos/src/services/clip-service.ts b/web/apps/photos/src/services/clip-service.ts index 9528d50753..eae9590fd5 100644 --- a/web/apps/photos/src/services/clip-service.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -184,8 +184,8 @@ class CLIPService { } }; - getTextEmbedding = async (text: string) => { - return ensureElectron().clipTextEmbedding(text); + getTextEmbeddingIfAvailable = async (text: string) => { + return ensureElectron().clipTextEmbeddingIfAvailable(text); }; private runClipEmbeddingExtraction = async (canceller: AbortController) => { diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts index 408c3daa5f..50e95da5b5 100644 --- a/web/apps/photos/src/services/searchService.ts +++ b/web/apps/photos/src/services/searchService.ts @@ -1,5 +1,4 @@ import log from "@/next/log"; -import { CustomError } from "@ente/shared/error"; import * as chrono from "chrono-node"; import { FILE_TYPE } from "constants/file"; import { t } from "i18next"; @@ -287,24 +286,20 @@ async function getLocationSuggestions(searchPhrase: string) { return [...locationTagSuggestions, ...citySearchSuggestions]; } -async function 
getClipSuggestion(searchPhrase: string): Promise { - try { - if (!clipService.isPlatformSupported()) { - return null; - } - - const clipResults = await searchClip(searchPhrase); - return { - type: SuggestionType.CLIP, - value: clipResults, - label: searchPhrase, - }; - } catch (e) { - if (!e.message?.includes(CustomError.MODEL_DOWNLOAD_PENDING)) { - log.error("getClipSuggestion failed", e); - } +async function getClipSuggestion( + searchPhrase: string, +): Promise { + if (!clipService.isPlatformSupported()) { return null; } + + const clipResults = await searchClip(searchPhrase); + if (!clipResults) return clipResults; + return { + type: SuggestionType.CLIP, + value: clipResults, + label: searchPhrase, + }; } function searchCollection( @@ -374,9 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise { return matchedLocationTags; } -async function searchClip(searchPhrase: string): Promise { +const searchClip = async ( + searchPhrase: string, +): Promise => { + const textEmbedding = + await clipService.getTextEmbeddingIfAvailable(searchPhrase); + if (!textEmbedding) return undefined; + const imageEmbeddings = await getLocalEmbeddings(); - const textEmbedding = await clipService.getTextEmbedding(searchPhrase); const clipSearchResult = new Map( ( await Promise.all( @@ -394,7 +394,7 @@ async function searchClip(searchPhrase: string): Promise { ); return clipSearchResult; -} +}; function convertSuggestionToSearchQuery(option: Suggestion): Search { switch (option.type) { diff --git a/web/packages/next/types/ipc.ts b/web/packages/next/types/ipc.ts index b699053886..7e51f407bb 100644 --- a/web/packages/next/types/ipc.ts +++ b/web/packages/next/types/ipc.ts @@ -240,7 +240,18 @@ export interface Electron { clipImageEmbedding: (jpegImageData: Uint8Array) => Promise; /** - * Return a CLIP embedding of the given image. + * Return a CLIP embedding of the given text if we already have the model + * downloaded and prepped.
If the model is not available return `undefined`. + * + * This differs from the other sibling ML functions in that it doesn't wait + * for the model download to finish. It does trigger a model download, but + * then immediately returns `undefined`. At some future point, when the + * model download finishes, calls to this function will start returning + * the result we seek. + * + * The reason for doing it in this asymmetric way is because CLIP text + * embeddings are used as part of deducing user initiated search results, + * and we don't want to block that interaction on a large network request. * * See: [Note: CLIP based magic search] * @@ -248,7 +259,9 @@ export interface Electron { * * @returns A CLIP embedding. */ - clipTextEmbedding: (text: string) => Promise; + clipTextEmbeddingIfAvailable: ( + text: string, + ) => Promise; /** * Detect faces in the given image using YOLO. diff --git a/web/packages/shared/error/index.ts b/web/packages/shared/error/index.ts index 12a87d2dba..c91f21413f 100644 --- a/web/packages/shared/error/index.ts +++ b/web/packages/shared/error/index.ts @@ -84,8 +84,6 @@ export const CustomError = { ServerError: "server error", FILE_NOT_FOUND: "file not found", UNSUPPORTED_PLATFORM: "Unsupported platform", - MODEL_DOWNLOAD_PENDING: - "Model download pending, skipping clip search request", UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch", URL_ALREADY_SET: "url already set", FILE_CONVERSION_FAILED: "file conversion failed",