Float32 opt for clip

Manav Rathi
2024-09-03 14:11:15 +05:30
parent b30705fa31
commit 5e65001e37
6 changed files with 56 additions and 9 deletions

View File

@@ -17,6 +17,7 @@ import AsyncSelect from "react-select/async";
 import { InputActionMeta } from "react-select/src/types";
 import { City } from "services/locationSearchService";
 import {
+    clearSearchCaches,
     getAutoCompleteSuggestions,
     getDefaultOptions,
 } from "services/searchService";
@@ -97,6 +98,7 @@ export default function SearchInput(props: Iprops) {
             props.setIsOpen(false);
             setValue(null);
             setQuery("");
+            clearSearchCaches();
         }
     };

View File

@@ -2,6 +2,7 @@ import { isDesktop } from "@/base/app";
 import log from "@/base/log";
 import { FileType } from "@/media/file-type";
 import {
+    clearCachedCLIPIndexes,
     clipMatches,
     isMLEnabled,
     isMLSupported,
@@ -61,6 +62,10 @@ export const getAutoCompleteSuggestions =
     }
 };
 
+export const clearSearchCaches = async () => {
+    await clearCachedCLIPIndexes();
+};
+
 async function convertSuggestionsToOptions(
     suggestions: Suggestion[],
 ): Promise<SearchOption[]> {

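For context, the intended lifecycle is that the UI clears these caches when the search bar is dismissed, so any per-session state lives only as long as the search session itself. A minimal sketch of that flow, assuming a close handler like the one modified above (the handler name is hypothetical; clearSearchCaches is the function added in this commit):

import { clearSearchCaches } from "services/searchService";

// Hypothetical close handler mirroring the SearchInput change above: when
// the user dismisses search, drop per-session caches so the Float32 copies
// of the CLIP embeddings don't outlive the search session.
const handleSearchClose = () => {
    // ... reset component state (value, query, etc.) ...
    void clearSearchCaches();
};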
View File

@@ -2,7 +2,7 @@ import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
 import { pixelRGBBilinear } from "./image";
-import { dotProduct, norm } from "./math";
+import { dotProductF32, normF32 } from "./math";
 import type { CLIPMatches } from "./worker-types";
 
 /**
@@ -106,13 +106,13 @@ export const indexCLIP = async (
     image: ImageBitmapAndData,
     electron: ElectronMLWorker,
 ): Promise<CLIPIndex> => ({
-    embedding: await computeEmbedding(image.data, electron),
+    embedding: Array.from(await computeEmbedding(image.data, electron)),
 });
 
 const computeEmbedding = async (
     imageData: ImageData,
     electron: ElectronMLWorker,
-): Promise<number[]> => {
+): Promise<Float32Array> => {
     const clipInput = convertToCLIPInput(imageData);
     return normalized(await electron.computeCLIPImageEmbedding(clipInput));
 };
@@ -158,9 +158,8 @@ const convertToCLIPInput = (imageData: ImageData) => {
 };
 
 const normalized = (embedding: Float32Array) => {
-    const nums = Array.from(embedding);
-    const n = norm(nums);
-    return nums.map((v) => v / n);
+    const n = normF32(embedding);
+    return embedding.map((v) => v / n);
 };
 
 /**
@@ -177,17 +176,42 @@ export const clipMatches = async (
     if (!t) return undefined;
     const textEmbedding = normalized(t);
 
-    const items = (await clipIndexes()).map(
+    const items = (await cachedOrReadCLIPIndexes()).map(
         ({ fileID, embedding }) =>
            // What we want to do is `cosineSimilarity`, but since both the
            // embeddings involved are already normalized, we can save the norm
            // calculations and directly do their `dotProduct`.
            //
            // This code is on the hot path, so these optimizations help.
-            [fileID, dotProduct(embedding, textEmbedding)] as const,
+            [fileID, dotProductF32(embedding, textEmbedding)] as const,
     );
 
     // This score threshold was obtained heuristically. 0.2 generally gives
     // solid results, and around 0.15 we start getting many false positives
     // (all this is query dependent too).
     return new Map(items.filter(([, score]) => score >= 0.175));
 };
+
+let _cachedCLIPIndexes:
+    | { fileID: number; embedding: Float32Array }[]
+    | undefined;
+
+/**
+ * Cache the CLIP indexes for the duration of a "search session" to avoid
+ * re-converting them from number[] to Float32Array on each match.
+ *
+ * Converting them to Float32Array gives a big performance boost (see [Note:
+ * Dot product performance]), but doing that conversion on every query loses
+ * the amortized benefit, so this temporary cache is an attempt to alleviate
+ * that.
+ *
+ * Once the user is done searching (for now), call
+ * {@link clearCachedCLIPIndexes}.
+ */
+const cachedOrReadCLIPIndexes = async () =>
+    (_cachedCLIPIndexes ??= (await clipIndexes()).map(
+        ({ fileID, embedding }) => ({
+            fileID,
+            embedding: new Float32Array(embedding),
+        }),
+    ));
+
+export const clearCachedCLIPIndexes = () => (_cachedCLIPIndexes = undefined);
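An aside on the comment above: cosine similarity is dot(a, b) / (‖a‖ · ‖b‖), so when both vectors have already been normalized to unit length the denominator is 1 and the raw dot product gives the identical score. A self-contained sketch demonstrating the equivalence (the helpers below are local to this example, not the app's math.ts):

// For unit vectors, the dot product equals the cosine similarity.
const dot = (a: number[], b: number[]) =>
    a.reduce((s, x, i) => s + x * b[i]!, 0);
const norm = (v: number[]) => Math.sqrt(v.reduce((s, x) => s + x * x, 0));
const normalize = (v: number[]) => v.map((x) => x / norm(v));

const a = [3, 4, 0]; // norm 5
const b = [1, 2, 2]; // norm 3

const cosine = dot(a, b) / (norm(a) * norm(b)); // 11 / 15 ≈ 0.7333
const dotOfUnits = dot(normalize(a), normalize(b)); // also ≈ 0.7333

// Normalizing once up front lets the hot matching loop skip both norm
// computations and the division on every comparison.
console.log(cosine, dotOfUnits);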

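The cache above is a lazy module-level memo: the first clipMatches call of a session pays the number[] to Float32Array conversion cost, subsequent queries reuse the typed arrays, and closing search discards them. Here is the same pattern in isolation, as a sketch with illustrative names (readEmbeddings stands in for a DB read like clipIndexes):

// Lazy session cache: convert on first use, reuse until explicitly cleared.
let _session: Float32Array[] | undefined;

// Stand-in for a DB read such as clipIndexes().
const readEmbeddings = async (): Promise<number[][]> => [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
];

const sessionEmbeddings = async () =>
    // `??=` assigns only when the cache is undefined, so the conversion runs
    // at most once per session.
    (_session ??= (await readEmbeddings()).map((e) => new Float32Array(e)));

const endSession = () => (_session = undefined);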
View File

@@ -602,6 +602,12 @@ export const clipMatches = (
 ): Promise<CLIPMatches | undefined> =>
     worker().then((w) => w.clipMatches(searchPhrase));
 
+/**
+ * Clear any cached intermediate state created during a search session.
+ */
+export const clearCachedCLIPIndexes = () =>
+    worker().then((w) => w.clearCachedCLIPIndexes());
+
 /**
  * Return the IDs of all the faces in the given {@link enteFile} that are not
  * associated with a person cluster.

View File

@@ -43,7 +43,6 @@ export const dotProductF32 = (v1: Float32Array, v2: Float32Array) => {
     return d;
 };
 
 /**
  * Return the L2-norm ("magnitude") of the given vector.
  *
@@ -54,6 +53,9 @@ export const dotProductF32 = (v1: Float32Array, v2: Float32Array) => {
 export const norm = (v: number[]) =>
     Math.sqrt(v.reduce((a, x) => a + x * x, 0));
 
+export const normF32 = (v: Float32Array) =>
+    Math.sqrt(v.reduce((a, x) => a + x * x, 0));
+
 /**
  * Return the cosine similarity of the two given vectors.
  *

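The premise behind this commit (the "[Note: Dot product performance]" referenced earlier; the note's body isn't part of this diff) is that a tight loop over a Float32Array is noticeably faster than the same loop over a plain number[]. A rough micro-benchmark sketch to verify that locally; the absolute numbers and the size of the gap depend on the JS engine:

// Compare dot product throughput: number[] vs Float32Array.
// Treat the output as a sanity check, not a definitive measurement.
const N = 512; // a typical CLIP embedding dimension
const ITERS = 100_000;

const arr: number[] = Array.from({ length: N }, () => Math.random());
const f32 = new Float32Array(arr);

const dotArr = (a: number[], b: number[]) => {
    let d = 0;
    for (let i = 0; i < a.length; i++) d += a[i]! * b[i]!;
    return d;
};

const dotF32 = (a: Float32Array, b: Float32Array) => {
    let d = 0;
    for (let i = 0; i < a.length; i++) d += a[i]! * b[i]!;
    return d;
};

let sink = 0;
let t = performance.now();
for (let i = 0; i < ITERS; i++) sink += dotArr(arr, arr);
console.log("number[]     ", (performance.now() - t).toFixed(1), "ms");

t = performance.now();
for (let i = 0; i < ITERS; i++) sink += dotF32(f32, f32);
console.log("Float32Array ", (performance.now() - t).toFixed(1), "ms");

console.log(sink > 0); // keep `sink` observable so the loops aren't elided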
View File

@@ -19,6 +19,7 @@ import {
     type ImageBitmapAndData,
 } from "./blob";
 import {
+    clearCachedCLIPIndexes,
     clipIndexingVersion,
     clipMatches,
     indexCLIP,
@@ -195,6 +196,13 @@ export class MLWorker {
         return clipMatches(searchPhrase, ensure(this.electron));
     }
 
+    /**
+     * Clear cached intermediate state preserved during a search "session".
+     */
+    clearCachedCLIPIndexes() {
+        clearCachedCLIPIndexes();
+    }
+
     private async tick() {
         log.debug(() => [
             "ml/tick",