Clip matches
@@ -24,8 +24,6 @@ import locationSearchService, { City } from "./locationSearchService";
 const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
 
-// const CLIP_SCORE_THRESHOLD = 0.23;
-
 export const getDefaultOptions = async () => {
     return [
         // TODO-ML(MR): Skip this for now if indexing is disabled (eventually
@@ -1,6 +1,7 @@
 import type { EnteFile } from "@/new/photos/types/file";
-import { ensure } from "@/utils/ensure";
+import type { Electron } from "@/next/types/ipc";
 import type { ImageBitmapAndData } from "./bitmap";
+import { clipIndexes } from "./db";
 import { pixelRGBBicubic } from "./image";
 import type { MLWorkerElectron } from "./worker-electron";
@@ -120,8 +121,9 @@ const computeEmbedding = async (
     electron: MLWorkerElectron,
 ): Promise<number[]> => {
     const clipInput = convertToCLIPInput(imageData);
-    const embedding = await electron.computeCLIPImageEmbedding(clipInput);
-    return Array.from(normalizeEmbedding(embedding));
+    return normalizedEmbedding(
+        await electron.computeCLIPImageEmbedding(clipInput),
+    );
 };
 
 /**
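After this refactor, computeEmbedding resolves to a plain number[] that normalizedEmbedding (defined in the next hunk) has already scaled to unit L2 norm. A minimal sketch of the invariant callers can rely on (assertUnitNorm is a hypothetical helper, not part of this commit):

    const assertUnitNorm = (embedding: number[]) => {
        // Math.hypot(...xs) computes the L2 norm sqrt(x0^2 + x1^2 + ...).
        const norm = Math.hypot(...embedding);
        if (Math.abs(norm - 1) > 1e-4)
            throw new Error(`expected a unit-normalized embedding, got ${norm}`);
    };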
@@ -168,13 +170,70 @@ const convertToCLIPInput = (imageData: ImageData) => {
     return clipInput;
 };
 
-const normalizeEmbedding = (embedding: Float32Array) => {
-    let normalization = 0;
-    for (const v of embedding) normalization += v * v;
-
-    const sqrtNormalization = Math.sqrt(normalization);
-    for (let index = 0; index < embedding.length; index++)
-        embedding[index] = ensure(embedding[index]) / sqrtNormalization;
-
-    return embedding;
+const normalizedEmbedding = (embedding: Float32Array) => {
+    const norm = embedding.reduce((a, v) => a + v * v, 0);
+    const sqrtNorm = Math.sqrt(norm);
+    return Array.from(embedding).map((v) => v / sqrtNorm);
 };
+
+/**
+ * Use CLIP to perform a natural language search over image embeddings.
+ *
+ * @param searchPhrase The text entered by the user in the search box.
+ *
+ * @param electron The {@link Electron} instance to use to communicate with the
+ * native code running in our desktop app (the embedding happens in the native
+ * layer).
+ *
+ * It returns a list of files that should be shown in the search results. The
+ * actual return type is a map from fileIDs to the scores they got (higher is
+ * better). This map will only contain entries whose score was above our
+ * minimum threshold.
+ *
+ * The result can also be `undefined`, which indicates that the download for the
+ * ML model is still in progress (trying again later should succeed).
+ */
+export const clipMatches = async (
+    searchPhrase: string,
+    electron: Electron,
+): Promise<Map<number, number> | undefined> => {
+    const textEmbedding = await electron
+        .computeCLIPTextEmbeddingIfAvailable(searchPhrase)
+        .then((e) => (e ? normalizedEmbedding(e) : e));
+    if (!textEmbedding) return undefined;
+
+    const imageEmbeddings = await clipIndexes();
+
+    return new Map<number, number>(
+        (
+            await Promise.all(
+                imageEmbeddings.map(
+                    async ({
+                        fileID,
+                        embedding,
+                    }): Promise<[number, number]> => [
+                        fileID,
+                        await clipMatchScore(embedding, textEmbedding),
+                    ],
+                ),
+            )
+        ).filter(([, score]) => score >= 0.23),
+    );
+};
+
+const clipMatchScore = async (
+    imageEmbedding: number[],
+    textEmbedding: number[],
+) => {
+    if (imageEmbedding.length != textEmbedding.length)
+        throw Error(
+            `CLIP image embeddings (${imageEmbedding.length}) and text embeddings (${textEmbedding.length}) length mismatch`,
+        );
+
+    let score = 0;
+    for (let i = 0; i < imageEmbedding.length; i++) {
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        score += imageEmbedding[i]! * textEmbedding[i]!;
+    }
+    return score;
+};
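Because both embeddings are unit-normalized, the dot product computed by clipMatchScore is exactly their cosine similarity, and the inline 0.23 cutoff takes over from the commented-out CLIP_SCORE_THRESHOLD deleted in the first hunk. A hypothetical caller could consume the returned score map like this (searchByPhrase and its files parameter are assumptions for illustration, not part of the commit):

    const searchByPhrase = async (
        searchPhrase: string,
        files: EnteFile[],
        electron: Electron,
    ) => {
        const matches = await clipMatches(searchPhrase, electron);
        // undefined: the CLIP model is still being downloaded; retry later.
        if (!matches) return undefined;
        // Keep only files that crossed the threshold, best matches first.
        return files
            .filter((file) => matches.has(file.id))
            .sort((a, b) => matches.get(b.id)! - matches.get(a.id)!);
    };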
@@ -272,11 +272,11 @@ export const faceIndex = async (fileID: number) => {
 };
 
 /**
- * Return all CLIP embeddings present locally.
+ * Return all CLIP indexes present locally.
  */
-export const clipEmbeddings = async () => {
+export const clipIndexes = async () => {
     const db = await mlDB();
-    return (await db.getAll("clip-index")).map((index) => index.embedding);
+    return await db.getAll("clip-index");
 };
 
 /**
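Note the return shape change that accompanies the rename: clipEmbeddings returned bare embedding arrays, while clipIndexes returns the stored rows themselves, which is what allows clipMatches to destructure { fileID, embedding } from each entry. The row shape implied by that destructuring (an inferred sketch; the actual type lives with the db schema):

    interface CLIPIndex {
        fileID: number;
        embedding: number[];
    }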