Clip matches

Manav Rathi
2024-07-09 15:29:34 +05:30
parent 920685e15c
commit b8b2d58734
3 changed files with 74 additions and 17 deletions

View File

@@ -24,8 +24,6 @@ import locationSearchService, { City } from "./locationSearchService";
 const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
-// const CLIP_SCORE_THRESHOLD = 0.23;
 export const getDefaultOptions = async () => {
     return [
         // TODO-ML(MR): Skip this for now if indexing is disabled (eventually

View File

@@ -1,6 +1,7 @@
 import type { EnteFile } from "@/new/photos/types/file";
 import { ensure } from "@/utils/ensure";
+import type { Electron } from "@/next/types/ipc";
 import type { ImageBitmapAndData } from "./bitmap";
+import { clipIndexes } from "./db";
 import { pixelRGBBicubic } from "./image";
 import type { MLWorkerElectron } from "./worker-electron";
@@ -120,8 +121,9 @@ const computeEmbedding = async (
     electron: MLWorkerElectron,
 ): Promise<number[]> => {
     const clipInput = convertToCLIPInput(imageData);
-    const embedding = await electron.computeCLIPImageEmbedding(clipInput);
-    return Array.from(normalizeEmbedding(embedding));
+    return normalizedEmbedding(
+        await electron.computeCLIPImageEmbedding(clipInput),
+    );
 };
 
 /**
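Note (reviewer sketch, not part of the commit): the refactored computeEmbedding returns a freshly allocated, unit-normalized number[] instead of mutating the incoming Float32Array in place. Because clipMatchScore (added in the next hunk) scores by a plain dot product, unit normalization is what makes that score the cosine similarity. A quick standalone check, with normalizedEmbedding copied from the hunk below:

// Sketch: verify that normalizedEmbedding (defined in the next hunk) yields
// a unit vector, so a dot product of two such vectors is a cosine similarity.
const normalizedEmbedding = (embedding: Float32Array) => {
    const norm = embedding.reduce((a, v) => a + v * v, 0);
    const sqrtNorm = Math.sqrt(norm);
    return Array.from(embedding).map((v) => v / sqrtNorm);
};

const v = normalizedEmbedding(new Float32Array([3, 4]));
console.log(v); // [0.6, 0.8]
console.log(v.reduce((a, x) => a + x * x, 0)); // ≈ 1, i.e. unit norm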
@@ -168,13 +170,70 @@ const convertToCLIPInput = (imageData: ImageData) => {
     return clipInput;
 };
 
-const normalizeEmbedding = (embedding: Float32Array) => {
-    let normalization = 0;
-    for (const v of embedding) normalization += v * v;
-    const sqrtNormalization = Math.sqrt(normalization);
-    for (let index = 0; index < embedding.length; index++)
-        embedding[index] = ensure(embedding[index]) / sqrtNormalization;
-    return embedding;
-};
+const normalizedEmbedding = (embedding: Float32Array) => {
+    const norm = embedding.reduce((a, v) => a + v * v, 0);
+    const sqrtNorm = Math.sqrt(norm);
+    return Array.from(embedding).map((v) => v / sqrtNorm);
+};
+
+/**
+ * Use CLIP to perform a natural language search over image embeddings.
+ *
+ * @param searchPhrase The text entered by the user in the search box.
+ *
+ * @param electron The {@link Electron} instance to use to communicate with the
+ * native code running in our desktop app (the embedding happens in the native
+ * layer).
+ *
+ * It returns a list of files that should be shown in the search results. The
+ * actual return type is a map from fileIDs to the scores they got (higher is
+ * better). This map will only contain entries whose score was above our
+ * minimum threshold.
+ *
+ * The result can also be `undefined`, which indicates that the download for the
+ * ML model is still in progress (trying again later should succeed).
+ */
+export const clipMatches = async (
+    searchPhrase: string,
+    electron: Electron,
+): Promise<Map<number, number> | undefined> => {
+    const textEmbedding = await electron
+        .computeCLIPTextEmbeddingIfAvailable(searchPhrase)
+        .then((e) => (e ? normalizedEmbedding(e) : e));
+    if (!textEmbedding) return undefined;
+    const imageEmbeddings = await clipIndexes();
+    return new Map<number, number>(
+        (
+            await Promise.all(
+                imageEmbeddings.map(
+                    async ({
+                        fileID,
+                        embedding,
+                    }): Promise<[number, number]> => [
+                        fileID,
+                        await clipMatchScore(embedding, textEmbedding),
+                    ],
+                ),
+            )
+        ).filter(([, score]) => score >= 0.23),
+    );
+};
+
+const clipMatchScore = async (
+    imageEmbedding: number[],
+    textEmbedding: number[],
+) => {
+    if (imageEmbedding.length != textEmbedding.length)
+        throw new Error(
+            `CLIP image embedding (${imageEmbedding.length}) and text embedding (${textEmbedding.length}) length mismatch`,
+        );
+    let score = 0;
+    for (let i = 0; i < imageEmbedding.length; i++) {
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        score += imageEmbedding[i]! * textEmbedding[i]!;
+    }
+    return score;
+};
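Note (hypothetical usage sketch, not part of the commit): a caller could turn the score map returned by clipMatches into ranked search results along these lines. searchFiles and the files argument are invented names; clipMatches and ensure come from the code above, and EnteFile is assumed to carry a numeric id.

// Hypothetical caller: rank the user's files by their CLIP match score.
// Assumes it lives in the same module as clipMatches.
const searchFiles = async (
    searchPhrase: string,
    electron: Electron,
    files: EnteFile[],
): Promise<EnteFile[] | undefined> => {
    const matches = await clipMatches(searchPhrase, electron);
    // undefined here means the CLIP model download is still in progress.
    if (!matches) return undefined;
    return files
        .filter((file) => matches.has(file.id))
        .sort((a, b) => ensure(matches.get(b.id)) - ensure(matches.get(a.id)));
};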

View File

@@ -272,11 +272,11 @@ export const faceIndex = async (fileID: number) => {
 };
 
 /**
- * Return all CLIP embeddings present locally.
+ * Return all CLIP indexes present locally.
  */
-export const clipEmbeddings = async () => {
+export const clipIndexes = async () => {
     const db = await mlDB();
-    return (await db.getAll("clip-index")).map((index) => index.embedding);
+    return await db.getAll("clip-index");
 };
 
 /**
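Note (sketch, not part of the commit): clipIndexes now returns the full stored index objects rather than bare embeddings. Any remaining caller of the old clipEmbeddings can recover the previous shape by projecting out the embedding field, which is exactly the logic of the removed line:

// Equivalent of the removed clipEmbeddings, expressed via clipIndexes.
const clipEmbeddings = async () =>
    (await clipIndexes()).map((index) => index.embedding);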