From b8b2d5873452ab812e77f453f6a795dba89b16e7 Mon Sep 17 00:00:00 2001
From: Manav Rathi
Date: Tue, 9 Jul 2024 15:29:34 +0530
Subject: [PATCH] Clip matches

---
 web/apps/photos/src/services/searchService.ts |  2 -
 web/packages/new/photos/services/ml/clip.ts   | 83 ++++++++++++++++---
 web/packages/new/photos/services/ml/db.ts     |  6 +-
 3 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts
index 84a38e144a..c1625488e3 100644
--- a/web/apps/photos/src/services/searchService.ts
+++ b/web/apps/photos/src/services/searchService.ts
@@ -24,8 +24,6 @@ import locationSearchService, { City } from "./locationSearchService";
 
 const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
 
-// const CLIP_SCORE_THRESHOLD = 0.23;
-
 export const getDefaultOptions = async () => {
     return [
         // TODO-ML(MR): Skip this for now if indexing is disabled (eventually
diff --git a/web/packages/new/photos/services/ml/clip.ts b/web/packages/new/photos/services/ml/clip.ts
index ef59fed210..adad41b18b 100644
--- a/web/packages/new/photos/services/ml/clip.ts
+++ b/web/packages/new/photos/services/ml/clip.ts
@@ -1,6 +1,7 @@
 import type { EnteFile } from "@/new/photos/types/file";
-import { ensure } from "@/utils/ensure";
+import type { Electron } from "@/next/types/ipc";
 import type { ImageBitmapAndData } from "./bitmap";
+import { clipIndexes } from "./db";
 import { pixelRGBBicubic } from "./image";
 import type { MLWorkerElectron } from "./worker-electron";
 
@@ -120,8 +121,9 @@ const computeEmbedding = async (
     electron: MLWorkerElectron,
 ): Promise<number[]> => {
     const clipInput = convertToCLIPInput(imageData);
-    const embedding = await electron.computeCLIPImageEmbedding(clipInput);
-    return Array.from(normalizeEmbedding(embedding));
+    return normalizedEmbedding(
+        await electron.computeCLIPImageEmbedding(clipInput),
+    );
 };
 
 /**
@@ -168,13 +170,70 @@ const convertToCLIPInput = (imageData: ImageData) => {
     return clipInput;
 };
 
-const normalizeEmbedding = (embedding: Float32Array) => {
-    let normalization = 0;
-    for (const v of embedding) normalization += v * v;
-
-    const sqrtNormalization = Math.sqrt(normalization);
-    for (let index = 0; index < embedding.length; index++)
-        embedding[index] = ensure(embedding[index]) / sqrtNormalization;
-
-    return embedding;
+const normalizedEmbedding = (embedding: Float32Array) => {
+    const norm = embedding.reduce((a, v) => a + v * v, 0);
+    const sqrtNorm = Math.sqrt(norm);
+    return Array.from(embedding).map((v) => v / sqrtNorm);
+};
+
+/**
+ * Use CLIP to perform a natural language search over image embeddings.
+ *
+ * @param searchPhrase The text entered by the user in the search box.
+ *
+ * @param electron The {@link Electron} instance to use to communicate with the
+ * native code running in our desktop app (the embedding happens in the native
+ * layer).
+ *
+ * It returns the list of files that should be shown in the search results. The
+ * actual return type is a map from fileIDs to the scores they got (higher is
+ * better). This map will only contain entries whose score was above our
+ * minimum threshold.
+ *
+ * The result can also be `undefined`, which indicates that the download of the
+ * ML model is still in progress (trying again later should succeed).
+ */
+export const clipMatches = async (
+    searchPhrase: string,
+    electron: Electron,
+): Promise<Map<number, number> | undefined> => {
+    const textEmbedding = await electron
+        .computeCLIPTextEmbeddingIfAvailable(searchPhrase)
+        .then((e) => (e ? normalizedEmbedding(e) : e));
+    if (!textEmbedding) return undefined;
+
+    const imageEmbeddings = await clipIndexes();
+
+    return new Map(
+        (
+            await Promise.all(
+                imageEmbeddings.map(
+                    async ({
+                        fileID,
+                        embedding,
+                    }): Promise<[number, number]> => [
+                        fileID,
+                        await clipMatchScore(embedding, textEmbedding),
+                    ],
+                ),
+            )
+        ).filter(([, score]) => score >= 0.23),
+    );
+};
+
+const clipMatchScore = async (
+    imageEmbedding: number[],
+    textEmbedding: number[],
+) => {
+    if (imageEmbedding.length != textEmbedding.length)
+        throw Error(
+            `CLIP image embeddings (${imageEmbedding.length}) and text embeddings (${textEmbedding.length}) length mismatch`,
+        );
+
+    let score = 0;
+    for (let i = 0; i < imageEmbedding.length; i++) {
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        score += imageEmbedding[i]! * textEmbedding[i]!;
+    }
+    return score;
 };
diff --git a/web/packages/new/photos/services/ml/db.ts b/web/packages/new/photos/services/ml/db.ts
index b75bb9e6c6..6a0fee1b6a 100644
--- a/web/packages/new/photos/services/ml/db.ts
+++ b/web/packages/new/photos/services/ml/db.ts
@@ -272,11 +272,11 @@ export const faceIndex = async (fileID: number) => {
 };
 
 /**
- * Return all CLIP embeddings present locally.
+ * Return all CLIP indexes present locally.
  */
-export const clipEmbeddings = async () => {
+export const clipIndexes = async () => {
     const db = await mlDB();
-    return (await db.getAll("clip-index")).map((index) => index.embedding);
+    return await db.getAll("clip-index");
 };
 
 /**
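
Note (not part of the patch): a rough sketch of how a caller, for example the
search flow in web/apps/photos/src/services/searchService.ts, might consume the
new clipMatches() helper. The rankedCLIPMatches name, the import paths, and the
assumption that EnteFile exposes a numeric `id` are illustrative guesses, not
code from this commit.

    import type { Electron } from "@/next/types/ipc";
    import type { EnteFile } from "@/new/photos/types/file";
    import { clipMatches } from "@/new/photos/services/ml/clip";

    /**
     * Sketch: rank the given files for a search phrase using clipMatches.
     *
     * Returns the matching files sorted by descending CLIP score, or
     * `undefined` while the text model is still being downloaded (the caller
     * can retry later).
     */
    export const rankedCLIPMatches = async (
        searchPhrase: string,
        files: EnteFile[],
        electron: Electron,
    ): Promise<EnteFile[] | undefined> => {
        // Map from fileID to CLIP score; only files above the threshold appear.
        const scores = await clipMatches(searchPhrase, electron);
        if (!scores) return undefined;
        return files
            .filter((file) => scores.has(file.id))
            .sort((a, b) => (scores.get(b.id) ?? 0) - (scores.get(a.id) ?? 0));
    };

Because both the image and text embeddings are unit-normalized by
normalizedEmbedding(), the dot product computed in clipMatchScore() is their
cosine similarity, so the 0.23 cutoff in clipMatches() acts as a
cosine-similarity threshold; callers only ever see files that cleared it.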