From 5a3838be342ac2bb9656b67bb6a27b12c2c3b0ea Mon Sep 17 00:00:00 2001
From: Manav Rathi
Date: Wed, 31 Jul 2024 12:30:15 +0530
Subject: [PATCH] Route via workers

---
 web/apps/photos/src/services/searchService.ts |  5 ++--
 web/packages/new/photos/services/ml/clip.ts   | 20 ++++------------
 web/packages/new/photos/services/ml/index.ts  | 17 +++++++++++++
 .../new/photos/services/ml/worker-types.ts    | 14 +++++++++--
 web/packages/new/photos/services/ml/worker.ts | 24 +++++++++++++++----
 5 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts
index 8b90652e64..750a1fb186 100644
--- a/web/apps/photos/src/services/searchService.ts
+++ b/web/apps/photos/src/services/searchService.ts
@@ -1,13 +1,12 @@
 import { isDesktop } from "@/base/app";
-import { ensureElectron } from "@/base/electron";
 import log from "@/base/log";
 import { FileType } from "@/media/file-type";
 import {
+    clipMatches,
     isMLEnabled,
     isMLSupported,
     mlStatusSnapshot,
 } from "@/new/photos/services/ml";
-import { clipMatches } from "@/new/photos/services/ml/clip";
 import type { Person } from "@/new/photos/services/ml/people";
 import { EnteFile } from "@/new/photos/types/file";
 import * as chrono from "chrono-node";
@@ -374,7 +373,7 @@ const searchClip = async (
     searchPhrase: string,
 ): Promise<ClipSearchScores | undefined> => {
     if (!isMLEnabled()) return undefined;
-    const matches = await clipMatches(searchPhrase, ensureElectron());
+    const matches = await clipMatches(searchPhrase);
     log.debug(() => ["clip/scores", matches]);
     return matches;
 };
diff --git a/web/packages/new/photos/services/ml/clip.ts b/web/packages/new/photos/services/ml/clip.ts
index f6230f2466..78eff1c04d 100644
--- a/web/packages/new/photos/services/ml/clip.ts
+++ b/web/packages/new/photos/services/ml/clip.ts
@@ -1,8 +1,9 @@
-import type { Electron, ElectronMLWorker } from "@/base/types/ipc";
+import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
 import { pixelRGBBicubic } from "./image";
 import { dotProduct, norm } from "./math";
+import type { CLIPMatches } from "./worker-types";
 
 /**
  * The version of the CLIP indexing pipeline implemented by the current client.
@@ -166,26 +167,15 @@ const normalized = (embedding: Float32Array) => {
 };
 
 /**
- * Use CLIP to perform a natural language search over image embeddings.
- *
- * @param searchPhrase The text entered by the user in the search box.
- *
- * @param electron The {@link Electron} instance to use to communicate with the
- * native code running in our desktop app (the embedding happens in the native
- * layer).
- *
- * It returns file (IDs) that should be shown in the search results. They're
- * returned as a map from fileIDs to the scores they got (higher is better).
- * This map will only contains entries whose score was above our minimum
- * threshold.
+ * Find the files whose CLIP embedding "matches" the given {@link searchPhrase}.
  *
  * The result can also be `undefined`, which indicates that the download for the
  * ML model is still in progress (trying again later should succeed).
  */
 export const clipMatches = async (
     searchPhrase: string,
-    electron: Electron,
-): Promise<Map<number, number> | undefined> => {
+    electron: ElectronMLWorker,
+): Promise<CLIPMatches | undefined> => {
     const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase);
     if (!t) return undefined;
diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts
index 5c114c0eb1..ba8083ad74 100644
--- a/web/packages/new/photos/services/ml/index.ts
+++ b/web/packages/new/photos/services/ml/index.ts
@@ -19,6 +19,7 @@ import type { UploadItem } from "../upload/types";
 import { regenerateFaceCrops } from "./crop";
 import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
 import { MLWorker } from "./worker";
+import type { CLIPMatches } from "./worker-types";
 
 /**
  * In-memory flag that tracks if ML is enabled.
@@ -392,6 +393,22 @@ const setInterimScheduledStatus = () => {
 
 const workerDidProcessFile = throttled(updateMLStatusSnapshot, 2000);
 
+/**
+ * Use CLIP to perform a natural language search over image embeddings.
+ *
+ * @param searchPhrase The text entered by the user in the search box.
+ *
+ * It returns file (IDs) that should be shown in the search results, along with
+ * their scores.
+ *
+ * The result can also be `undefined`, which indicates that the download for the
+ * ML model is still in progress (trying again later should succeed).
+ */
+export const clipMatches = (
+    searchPhrase: string,
+): Promise<CLIPMatches | undefined> =>
+    worker().then((w) => w.clipMatches(searchPhrase));
+
 /**
  * Return the IDs of all the faces in the given {@link enteFile} that are not
  * associated with a person cluster.
diff --git a/web/packages/new/photos/services/ml/worker-types.ts b/web/packages/new/photos/services/ml/worker-types.ts
index a83a215ea4..72d6bce61b 100644
--- a/web/packages/new/photos/services/ml/worker-types.ts
+++ b/web/packages/new/photos/services/ml/worker-types.ts
@@ -1,6 +1,5 @@
 /**
- * @file Types for the objects shared (as a Comlink proxy) by the main thread
- * and the ML worker.
+ * @file Types for the objects shared between the main thread and the ML worker.
  */
 
 /**
@@ -15,3 +14,14 @@ export interface MLWorkerDelegate {
      */
     workerDidProcessFile: () => void;
 }
+
+/**
+ * The file IDs that should be considered matches for a particular search
+ * phrase, each with their associated score.
+ *
+ * This is a map of file (IDs) that should be shown in the search results.
+ * They're returned as a map from fileIDs to the scores they got (higher is
+ * better). This map will only contain entries whose score was above our
+ * minimum threshold.
+ */
+export type CLIPMatches = Map<number, number>;
diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts
index e137531f6d..cbeba5f844 100644
--- a/web/packages/new/photos/services/ml/worker.ts
+++ b/web/packages/new/photos/services/ml/worker.ts
@@ -19,7 +19,12 @@ import {
     indexableBlobs,
     type ImageBitmapAndData,
 } from "./blob";
-import { clipIndexingVersion, indexCLIP, type CLIPIndex } from "./clip";
+import {
+    clipIndexingVersion,
+    clipMatches,
+    indexCLIP,
+    type CLIPIndex,
+} from "./clip";
 import { saveFaceCrops } from "./crop";
 import {
     indexableFileIDs,
@@ -33,7 +38,7 @@ import {
     type RemoteDerivedData,
 } from "./embedding";
 import { faceIndexingVersion, indexFaces, type FaceIndex } from "./face";
-import type { MLWorkerDelegate } from "./worker-types";
+import type { CLIPMatches, MLWorkerDelegate } from "./worker-types";
 
 const idleDurationStart = 5; /* 5 seconds */
 const idleDurationMax = 16 * 60; /* 16 minutes */
@@ -68,6 +73,9 @@ interface IndexableItem {
  * - "backfillq": fetching remote embeddings of unindexed items, and then
  *   indexing them if needed,
  * - "idle": in between state transitions.
+ *
+ * In addition, MLWorker can also be invoked for interactive tasks: in
+ * particular, for finding the closest CLIP match when the user does a search.
  */
 export class MLWorker {
     private electron: ElectronMLWorker | undefined;
@@ -178,6 +186,13 @@ export class MLWorker {
         return this.state == "indexing";
     }
 
+    /**
+     * Find {@link CLIPMatches} for a given {@link searchPhrase}.
+     */
+    async clipMatches(searchPhrase: string): Promise<CLIPMatches | undefined> {
+        return clipMatches(searchPhrase, ensure(this.electron));
+    }
+
     private async tick() {
         log.debug(() => [
             "ml/tick",
@@ -226,7 +241,7 @@ export class MLWorker {
     }
 
     /** Return the next batch of items to backfill (if any). */
-    async backfillQ() {
+    private async backfillQ() {
         const userID = ensure(await getKVN("userID"));
         // Find files that our local DB thinks need syncing.
         const filesByID = await syncWithLocalFilesAndGetFilesToIndex(
@@ -278,7 +293,8 @@ const indexNextBatch = async (
         try {
             await index(item, electron);
             delegate?.workerDidProcessFile();
-            // Possibly unnecessary, but let us drain the microtask queue.
+            // Let us drain the microtask queue. This also gives a chance for other
+            // interactive tasks like `clipMatches` to run.
             await wait(0);
         } catch (e) {
             log.warn(`Skipping unindexable file ${item.enteFile.id}`, e);
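Note (not part of the patch): a minimal usage sketch of the new main-thread entry point added by this change. The `showSearchResults` helper and the logging are hypothetical; the imports mirror the ones `searchService.ts` uses above, and the returned map is the `CLIPMatches` (fileID to score) type added in worker-types.ts.

    // Hypothetical main-thread caller. The re-exported `clipMatches` resolves
    // the ML web worker and forwards the call to it, so the caller no longer
    // needs an Electron handle.
    import { clipMatches, isMLEnabled } from "@/new/photos/services/ml";

    export const showSearchResults = async (searchPhrase: string) => {
        if (!isMLEnabled()) return;
        const matches = await clipMatches(searchPhrase);
        // `undefined` indicates the CLIP model download hasn't completed yet.
        if (!matches) return;
        // Sort the matching file IDs by their score, best match first.
        const fileIDs = [...matches.entries()]
            .sort(([, a], [, b]) => b - a)
            .map(([fileID]) => fileID);
        console.log("CLIP matches (best first)", fileIDs);
    };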