diff --git a/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx b/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx
index 18b7289915..a41983d5c3 100644
--- a/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx
+++ b/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx
@@ -17,6 +17,7 @@ import AsyncSelect from "react-select/async";
 import { InputActionMeta } from "react-select/src/types";
 import { City } from "services/locationSearchService";
 import {
+    clearSearchCaches,
     getAutoCompleteSuggestions,
     getDefaultOptions,
 } from "services/searchService";
@@ -97,6 +98,7 @@ export default function SearchInput(props: Iprops) {
             props.setIsOpen(false);
             setValue(null);
             setQuery("");
+            clearSearchCaches();
         }
     };
 
diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts
index 76b1d5cb2b..35e4015cc1 100644
--- a/web/apps/photos/src/services/searchService.ts
+++ b/web/apps/photos/src/services/searchService.ts
@@ -2,6 +2,7 @@ import { isDesktop } from "@/base/app";
 import log from "@/base/log";
 import { FileType } from "@/media/file-type";
 import {
+    clearCachedCLIPIndexes,
     clipMatches,
     isMLEnabled,
     isMLSupported,
@@ -61,6 +62,10 @@ export const getAutoCompleteSuggestions =
         }
     };
 
+export const clearSearchCaches = async () => {
+    await clearCachedCLIPIndexes();
+};
+
 async function convertSuggestionsToOptions(
     suggestions: Suggestion[],
 ): Promise<SearchOption[]> {
diff --git a/web/packages/new/photos/services/ml/clip.ts b/web/packages/new/photos/services/ml/clip.ts
index e81036ae56..c4e9ff6692 100644
--- a/web/packages/new/photos/services/ml/clip.ts
+++ b/web/packages/new/photos/services/ml/clip.ts
@@ -2,7 +2,7 @@ import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
 import { pixelRGBBilinear } from "./image";
-import { dotProduct, norm } from "./math";
+import { dotProductF32, normF32 } from "./math";
 import type { CLIPMatches } from "./worker-types";
 
 /**
@@ -106,13 +106,13 @@ export const indexCLIP = async (
     image: ImageBitmapAndData,
     electron: ElectronMLWorker,
 ): Promise<CLIPIndex> => ({
-    embedding: await computeEmbedding(image.data, electron),
+    embedding: Array.from(await computeEmbedding(image.data, electron)),
 });
 
 const computeEmbedding = async (
     imageData: ImageData,
     electron: ElectronMLWorker,
-): Promise<number[]> => {
+): Promise<Float32Array> => {
     const clipInput = convertToCLIPInput(imageData);
     return normalized(await electron.computeCLIPImageEmbedding(clipInput));
 };
@@ -158,9 +158,8 @@ const convertToCLIPInput = (imageData: ImageData) => {
 };
 
 const normalized = (embedding: Float32Array) => {
-    const nums = Array.from(embedding);
-    const n = norm(nums);
-    return nums.map((v) => v / n);
+    const n = normF32(embedding);
+    return embedding.map((v) => v / n);
 };
 
 /**
@@ -177,17 +176,42 @@ export const clipMatches = async (
     if (!t) return undefined;
 
     const textEmbedding = normalized(t);
-    const items = (await clipIndexes()).map(
+    const items = (await cachedOrReadCLIPIndexes()).map(
         ({ fileID, embedding }) =>
             // What we want to do is `cosineSimilarity`, but since both the
             // embeddings involved are already normalized, we can save the norm
             // calculations and directly do their `dotProduct`.
             //
             // This code is on the hot path, so these optimizations help.
-            [fileID, dotProduct(embedding, textEmbedding)] as const,
+            [fileID, dotProductF32(embedding, textEmbedding)] as const,
     );
+
     // This score threshold was obtained heuristically. 0.2 generally gives solid
     // results, and around 0.15 we start getting many false positives (all this
     // is query dependent too).
     return new Map(items.filter(([, score]) => score >= 0.175));
 };
+
+let _cachedCLIPIndexes:
+    | { fileID: number; embedding: Float32Array }[]
+    | undefined;
+
+/**
+ * Cache the CLIP indexes for the duration of a "search session" to avoid
+ * converting them from number[] to Float32Array during the match.
+ *
+ * Converting them to Float32Array gives a big performance boost (See: [Note:
+ * Dot product performance]). But doing that each time loses out on the
+ * amortized benefit, so this temporary cache is an attempt to alleviate that.
+ *
+ * Once the user is done searching (for now), call
+ * {@link clearCachedCLIPIndexes}.
+ */
+const cachedOrReadCLIPIndexes = async () =>
+    (_cachedCLIPIndexes ??=
+        (await clipIndexes()).map(({ fileID, embedding }) => ({
+            fileID,
+            embedding: new Float32Array(embedding),
+        })));
+
+export const clearCachedCLIPIndexes = () => (_cachedCLIPIndexes = undefined);
diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts
index 97ed45c7ea..24305fdba6 100644
--- a/web/packages/new/photos/services/ml/index.ts
+++ b/web/packages/new/photos/services/ml/index.ts
@@ -602,6 +602,12 @@ export const clipMatches = (
 ): Promise<CLIPMatches | undefined> =>
     worker().then((w) => w.clipMatches(searchPhrase));
 
+/**
+ * Clear any cached intermediate state created during a search session.
+ */
+export const clearCachedCLIPIndexes = () =>
+    worker().then((w) => w.clearCachedCLIPIndexes());
+
 /**
  * Return the IDs of all the faces in the given {@link enteFile} that are not
  * associated with a person cluster.
diff --git a/web/packages/new/photos/services/ml/math.ts b/web/packages/new/photos/services/ml/math.ts
index 7da08b3163..c34385cc07 100644
--- a/web/packages/new/photos/services/ml/math.ts
+++ b/web/packages/new/photos/services/ml/math.ts
@@ -43,7 +43,6 @@ export const dotProductF32 = (v1: Float32Array, v2: Float32Array) => {
     return d;
 };
 
-
 /**
  * Return the L2-norm ("magnitude") of the given vector.
  *
@@ -54,6 +53,9 @@ export const dotProductF32 = (v1: Float32Array, v2: Float32Array) => {
 export const norm = (v: number[]) =>
     Math.sqrt(v.reduce((a, x) => a + x * x, 0));
 
+export const normF32 = (v: Float32Array) =>
+    Math.sqrt(v.reduce((a, x) => a + x * x, 0));
+
 /**
  * Return the cosine similarity of the two given vectors.
  *
diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts
index 7b80d3282d..6408188769 100644
--- a/web/packages/new/photos/services/ml/worker.ts
+++ b/web/packages/new/photos/services/ml/worker.ts
@@ -19,6 +19,7 @@ import {
     type ImageBitmapAndData,
 } from "./blob";
 import {
+    clearCachedCLIPIndexes,
     clipIndexingVersion,
     clipMatches,
     indexCLIP,
@@ -195,6 +196,13 @@ export class MLWorker {
         return clipMatches(searchPhrase, ensure(this.electron));
     }
 
+    /**
+     * Clear cached intermediate state preserved during a search "session".
+     */
+    clearCachedCLIPIndexes() {
+        clearCachedCLIPIndexes();
+    }
+
     private async tick() {
         log.debug(() => [
             "ml/tick",
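
Note (not part of the patch): a minimal sketch of how the session cache introduced above is meant to be exercised. Repeated `clipMatches` calls within one search session reuse the `Float32Array` conversion of the stored indexes, and `clearSearchCaches` (which forwards to `clearCachedCLIPIndexes` on the ML worker) drops that cache when the search bar closes. The `runSearchSession` helper, the sample queries, and the import paths are illustrative assumptions, not code from this patch.

// Illustrative usage sketch; only `clipMatches` and `clearSearchCaches` come
// from this patch. Helper name and import paths are assumptions.
import { clipMatches } from "@/new/photos/services/ml";
import { clearSearchCaches } from "services/searchService";

const runSearchSession = async (queries: string[]) => {
    for (const q of queries) {
        // The first call converts the persisted number[] embeddings into
        // Float32Arrays and caches them; later calls reuse that conversion.
        const matches = await clipMatches(q);
        console.log(q, matches?.size ?? 0, "matching files");
    }
    // Once the user dismisses search, drop the cached Float32Array indexes
    // so they don't stay resident in the worker between sessions.
    await clearSearchCaches();
};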