[desktop] Consolidate face and CLIP pipelines - Part 2/x (#2408)

* Continuation of https://github.com/ente-io/ente/pull/2391
* CLIP indexing (and search) is disabled for now until the new format is
finalized
This commit is contained in:
Manav Rathi
2024-07-09 20:51:20 +05:30
committed by GitHub
33 changed files with 681 additions and 1309 deletions

View File

@@ -112,21 +112,14 @@ resources (`build`) folder. This is used for thumbnail generation on Linux.
On macOS, we use the `sips` CLI tool for conversion, but that is already
available on the host machine, and is not bundled with our app.
### AI/ML
### ML
[onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used as the
AI/ML runtime. It powers both natural language searches (using CLIP) and face
[onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used as the ML
runtime. It powers both natural language searches (using CLIP) and face
detection (using YOLO).
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
data into raw RGB bytes before passing it to ONNX.
html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.
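For orientation, a minimal sketch of how an onnxruntime-node inference session is typically created and run; the model path, input name, and tensor shape below are illustrative placeholders, not the app's actual values:

```ts
import * as ort from "onnxruntime-node";

const runInference = async () => {
    // Load the ONNX model weights and instantiate an inference session.
    const session = await ort.InferenceSession.create("/path/to/model.onnx");
    // Wrap preprocessed float data in a tensor of the shape the model expects.
    const input = new ort.Tensor(
        "float32",
        new Float32Array(3 * 224 * 224),
        [1, 3, 224, 224],
    );
    // Run inference; the result maps output names to tensors.
    const results = await session.run({ input });
    console.log(results.output?.dims);
};
```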
### Watch Folders
[chokidar](https://github.com/paulmillr/chokidar) is used as a file system
watcher for the watch folders functionality.
[clip-bpe-js](https://github.com/simonwarchol/clip-bpe-js) is used for tokenizing
the user's search phrase before computing its CLIP (text) embedding.
### ZIP
@@ -135,3 +128,8 @@ reading of large ZIP files (e.g. during imports of Google Takeout ZIPs).
[lru-cache](https://github.com/isaacs/node-lru-cache) is used to cache file ZIP
handles to avoid reopening them for every operation.
### Watch folders
[chokidar](https://github.com/paulmillr/chokidar) is used as a file system
watcher for the watch folders functionality.

View File

@@ -33,8 +33,6 @@
"electron-store": "^8.2",
"electron-updater": "^6.2",
"ffmpeg-static": "^5.2",
"html-entities": "^2.5",
"jpeg-js": "^0.4",
"lru-cache": "^10.2",
"next-electron-server": "^1",
"node-stream-zip": "^1.15",

View File

@@ -186,10 +186,8 @@ export const attachIPCHandlers = () => {
// - ML
ipcMain.handle(
"computeCLIPImageEmbedding",
(_, jpegImageData: Uint8Array) =>
computeCLIPImageEmbedding(jpegImageData),
ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) =>
computeCLIPImageEmbedding(input),
);
ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>

View File

@@ -1,23 +1,12 @@
// TODO: These arise from the array indexing in the pre-processing code. Isolate
// once that code settles down to its final place (currently duplicated across
// web and desktop).
/* eslint-disable @typescript-eslint/no-non-null-assertion */
/**
* @file Compute CLIP embeddings for images and text.
*
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
*/
import Tokenizer from "clip-bpe-js";
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import * as ort from "onnxruntime-node";
import log from "../log";
import { writeStream } from "../stream";
import { ensure, wait } from "../utils/common";
import { deleteTempFile, makeTempFilePath } from "../utils/temp";
import { makeCachedInferenceSession } from "./ml";
const cachedCLIPImageSession = makeCachedInferenceSession(
@@ -25,241 +14,16 @@ const cachedCLIPImageSession = makeCachedInferenceSession(
351468764 /* 335.2 MB */,
);
export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => {
const tempFilePath = await makeTempFilePath();
const imageStream = new Response(jpegImageData.buffer).body;
await writeStream(tempFilePath, ensure(imageStream));
try {
return await clipImageEmbedding_(tempFilePath);
} finally {
await deleteTempFile(tempFilePath);
}
};
const clipImageEmbedding_ = async (jpegFilePath: string) => {
export const computeCLIPImageEmbedding = async (input: Float32Array) => {
const session = await cachedCLIPImageSession();
const t1 = Date.now();
const rgbData = await getRGBData(jpegFilePath);
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
};
const t2 = Date.now();
const results = await session.run(feeds);
log.debug(
() =>
`ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to type the result */
const imageEmbedding = ensure(results.output).data as Float32Array;
return normalizeEmbedding(imageEmbedding);
};
const getRGBData = async (jpegFilePath: string): Promise<Float32Array> => {
const jpegData = await fs.readFile(jpegFilePath);
const rawImageData = jpeg.decode(jpegData, {
useTArray: true,
formatAsRGBA: true,
}); // TODO: manav: make sure this works on all images, not just jpeg
const pixelData = rawImageData.data;
const requiredWidth = 224;
const requiredHeight = 224;
const requiredSize = 3 * requiredWidth * requiredHeight;
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
const scale = Math.max(
requiredWidth / rawImageData.width,
requiredHeight / rawImageData.height,
);
const scaledWidth = Math.round(rawImageData.width * scale);
const scaledHeight = Math.round(rawImageData.height * scale);
const widthOffset = Math.max(0, scaledWidth - requiredWidth) / 2;
const heightOffset = Math.max(0, scaledHeight - requiredHeight) / 2;
const processedImage = new Float32Array(requiredSize);
// Populate the Float32Array with normalized pixel values.
let pi = 0;
const cOffsetG = requiredHeight * requiredWidth; // ChannelOffsetGreen
const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
const { r, g, b } = pixelRGBBicubic(
w / scale,
h / scale,
pixelData,
rawImageData.width,
rawImageData.height,
);
processedImage[pi] = (r / 255.0 - mean[0]!) / std[0]!;
processedImage[pi + cOffsetG] = (g / 255.0 - mean[1]!) / std[1]!;
processedImage[pi + cOffsetB] = (b / 255.0 - mean[2]!) / std[2]!;
pi++;
}
}
return processedImage;
};
// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
const pixelRGBBicubic = (
fx: number,
fy: number,
imageData: Uint8Array,
imageWidth: number,
imageHeight: number,
) => {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1;
const nx = x + 1;
const ax = x + 2;
const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1);
const py = y - 1;
const ny = y + 1;
const ay = y + 2;
const dx = fx - x;
const dy = fy - y;
const cubic = (
dx: number,
ipp: number,
icp: number,
inp: number,
iap: number,
) =>
icp +
0.5 *
(dx * (-ipp + inp) +
dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap));
const icc = pixelRGBA(imageData, imageWidth, imageHeight, x, y);
const ipp =
px < 0 || py < 0
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, py);
const icp =
px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, x, py);
const inp =
py < 0 || nx >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, py);
const iap =
ax >= imageWidth || py < 0
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, py);
const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r);
const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g);
const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b);
// const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
const ipc =
px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, px, y);
const inc =
nx >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, y);
const iac =
ax >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, y);
const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r);
const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g);
const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b);
// const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
const ipn =
px < 0 || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, ny);
const icn =
ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, x, ny);
const inn =
nx >= imageWidth || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, ny);
const ian =
ax >= imageWidth || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, ny);
const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r);
const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g);
const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b);
// const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
const ipa =
px < 0 || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, ay);
const ica =
ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, x, ay);
const ina =
nx >= imageWidth || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, ay);
const iaa =
ax >= imageWidth || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, ay);
const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r);
const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g);
const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b);
// const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255));
const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255));
const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255));
// const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 };
};
// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
const clamp = (value: number, min: number, max: number) =>
Math.min(max, Math.max(min, value));
// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
const pixelRGBA = (
imageData: Uint8Array,
width: number,
height: number,
x: number,
y: number,
) => {
if (x < 0 || x >= width || y < 0 || y >= height) {
return { r: 0, g: 0, b: 0, a: 0 };
}
const index = (y * width + x) * 4;
return {
r: ensure(imageData[index]),
g: ensure(imageData[index + 1]),
b: ensure(imageData[index + 2]),
a: ensure(imageData[index + 3]),
};
};
const normalizeEmbedding = (embedding: Float32Array) => {
let normalization = 0;
for (const v of embedding) normalization += v * v;
const sqrtNormalization = Math.sqrt(normalization);
for (let index = 0; index < embedding.length; index++)
embedding[index] = ensure(embedding[index]) / sqrtNormalization;
return embedding;
return ensure(results.output).data as Float32Array;
};
const cachedCLIPTextSession = makeCachedInferenceSession(
@@ -302,6 +66,5 @@ export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
() =>
`ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = ensure(results.output).data as Float32Array;
return normalizeEmbedding(textEmbedding);
return ensure(results.output).data as Float32Array;
};
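One consequence of this change worth noting (an observation drawn from the diff, using a hypothetical helper rather than code from the repository): both the image and text embeddings are now returned as raw model output, so normalization moves to the callers. Something along these lines on the web side:

```ts
// Hypothetical caller-side helper: convert a raw embedding to unit length so
// that cosine similarity can later be computed as a plain dot product.
const normalized = (embedding: Float32Array): number[] => {
    const n = Math.sqrt(embedding.reduce((sum, v) => sum + v * v, 0));
    return Array.from(embedding, (v) => v / n);
};
```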

View File

@@ -163,8 +163,8 @@ const ffmpegExec = (
// - ML
const computeCLIPImageEmbedding = (jpegImageData: Uint8Array) =>
ipcRenderer.invoke("computeCLIPImageEmbedding", jpegImageData);
const computeCLIPImageEmbedding = (input: Float32Array) =>
ipcRenderer.invoke("computeCLIPImageEmbedding", input);
const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);
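For context, a rough sketch of how the web layer invokes these after the change, assuming the usual contextBridge wiring that exposes the preload functions as `globalThis.electron`:

```ts
// The renderer now sends preprocessed CHW float data instead of JPEG bytes.
const clipInput = new Float32Array(3 * 224 * 224); // filled by the preprocessing step

const imageEmbedding: Float32Array =
    await globalThis.electron.computeCLIPImageEmbedding(clipInput);

// Text embeddings are unchanged; `undefined` means the model is still downloading.
const textEmbedding: Float32Array | undefined =
    await globalThis.electron.computeCLIPTextEmbeddingIfAvailable("sunset over the sea");
```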

View File

@@ -1900,11 +1900,6 @@ hosted-git-info@^4.1.0:
dependencies:
lru-cache "^6.0.0"
html-entities@^2.5:
version "2.5.2"
resolved "https://registry.yarnpkg.com/html-entities/-/html-entities-2.5.2.tgz#201a3cf95d3a15be7099521620d19dfb4f65359f"
integrity sha512-K//PSRMQk4FZ78Kyau+mZurHn3FH0Vwr+H36eE0rPbeYkRRi9YxceYPhuN60UwWorxyKHhqoAJl2OFKa4BVtaA==
http-cache-semantics@^4.0.0, http-cache-semantics@^4.1.0:
version "4.1.1"
resolved "https://registry.yarnpkg.com/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz#abe02fcb2985460bf0323be664436ec3476a6d5a"
@@ -2130,11 +2125,6 @@ jake@^10.8.5:
filelist "^1.0.4"
minimatch "^3.1.2"
jpeg-js@^0.4:
version "0.4.4"
resolved "https://registry.yarnpkg.com/jpeg-js/-/jpeg-js-0.4.4.tgz#a9f1c6f1f9f0fa80cdb3484ed9635054d28936aa"
integrity sha512-WZzeDOEtTOBK4Mdsar0IqEU5sMr3vSV2RqkAIzUEV2BHnUfKGyswWFPFwK5EeDo93K3FohSHbLAjj0s1Wzd+dg==
js-yaml@^4.1.0:
version "4.1.0"
resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-4.1.0.tgz#c1fb65f8f5017901cdd2c951864ba18458a10602"

View File

@@ -28,7 +28,6 @@
"memoize-one": "^6.0.0",
"ml-matrix": "^6.11",
"p-debounce": "^4.0.0",
"p-queue": "^7.1.0",
"photoswipe": "file:./thirdparty/photoswipe",
"piexifjs": "^1.0.6",
"pure-react-carousel": "^1.30.1",

View File

@@ -1,8 +1,7 @@
import { VerticallyCenteredFlex } from "@ente/shared/components/Container";
import { EnteMenuItem } from "@ente/shared/components/Menu/EnteMenuItem";
import ChevronRight from "@mui/icons-material/ChevronRight";
import ScienceIcon from "@mui/icons-material/Science";
import { Box, DialogProps, Stack, Typography } from "@mui/material";
import { Box, DialogProps, Stack } from "@mui/material";
import { EnteDrawer } from "components/EnteDrawer";
import { MenuItemGroup } from "components/Menu/MenuItemGroup";
import MenuSectionTitle from "components/Menu/MenuSectionTitle";
@@ -11,9 +10,7 @@ import { MLSearchSettings } from "components/ml/MLSearchSettings";
import { t } from "i18next";
import isElectron from "is-electron";
import { AppContext } from "pages/_app";
import { useContext, useEffect, useState } from "react";
import { CLIPIndexingStatus, clipService } from "services/clip-service";
import { formatNumber } from "utils/number/format";
import { useContext, useState } from "react";
export default function AdvancedSettings({ open, onClose, onRootClose }) {
const appContext = useContext(AppContext);
@@ -39,16 +36,17 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
appContext.setIsCFProxyDisabled(!appContext.isCFProxyDisabled);
};
const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
indexed: 0,
pending: 0,
});
// TODO-ML:
// const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
// indexed: 0,
// pending: 0,
// });
useEffect(() => {
clipService.setOnUpdateHandler(setIndexingStatus);
clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
return () => clipService.setOnUpdateHandler(undefined);
}, []);
// useEffect(() => {
// clipService.setOnUpdateHandler(setIndexingStatus);
// clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
// return () => clipService.setOnUpdateHandler(undefined);
// }, []);
return (
<EnteDrawer
@@ -97,7 +95,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
/>
</Box>
{isElectron() && (
{/* TODO-ML: isElectron() && (
<Box>
<MenuSectionTitle
title={t("MAGIC_SEARCH_STATUS")}
@@ -131,7 +129,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
</VerticallyCenteredFlex>
</Stack>
</Box>
)}
)*/}
</Stack>
</Box>
</Stack>

View File

@@ -81,7 +81,6 @@ import {
useState,
} from "react";
import { useDropzone } from "react-dropzone";
import { clipService } from "services/clip-service";
import {
constructEmailList,
constructUserIDToEmailMap,
@@ -390,7 +389,6 @@ export default function Gallery() {
syncWithRemote(false, true);
}, SYNC_INTERVAL_IN_MICROSECONDS);
if (electron) {
// void clipService.setupOnFileUploadListener();
electron.onMainWindowFocus(() => syncWithRemote(false, true));
if (await shouldShowWhatsNew()) setOpenWhatsNew(true);
}
@@ -398,10 +396,7 @@ export default function Gallery() {
main();
return () => {
clearInterval(syncInterval.current);
if (electron) {
electron.onMainWindowFocus(undefined);
clipService.removeOnFileUploadListener();
}
if (electron) electron.onMainWindowFocus(undefined);
};
}, []);

View File

@@ -1,402 +0,0 @@
import { FILE_TYPE } from "@/media/file-type";
import downloadManager from "@/new/photos/services/download";
import { getAllLocalFiles, getLocalFiles } from "@/new/photos/services/files";
import { EnteFile } from "@/new/photos/types/file";
import { ensureElectron } from "@/next/electron";
import log from "@/next/log";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error";
import { Events, eventBus } from "@ente/shared/events";
import { LS_KEYS, getData } from "@ente/shared/storage/localStorage";
import PQueue from "p-queue";
import { Embedding } from "types/embedding";
import { getPersonalFiles } from "utils/file";
import { localCLIPEmbeddings, putEmbedding } from "./embeddingService";
/** Status of CLIP indexing on the images in the user's local library. */
export interface CLIPIndexingStatus {
/** Number of items pending indexing. */
pending: number;
/** Number of items that have already been indexed. */
indexed: number;
}
/**
* Use a CLIP based neural network for natural language search.
*
* [Note: CLIP based magic search]
*
* CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
* (image, text) pairs. It can be thought of as two separate (but jointly
* trained) encoders - one for images, and one for text - that both map to the
* same embedding space.
*
* We use this for natural language search within the app (aka "magic search"):
*
* 1. Pre-compute an embedding for each image.
*
* 2. When the user searches, compute an embedding for the search term.
*
* 3. Use cosine similarity to find the image (embedding) closest to
* the text (embedding).
*
* More details are in our [blog
* post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
* initial launch of this feature using the GGML runtime.
*
* Since the initial launch, we've switched over to another runtime,
* [ONNX](https://onnxruntime.ai).
*
* Note that we don't train the neural network - we only use one of the publicly
* available pre-trained neural networks for inference. These neural networks
* are wholly defined by their connectivity and weights. ONNX, our ML runtime,
* loads these weights and instantiates a running network that we can use to
* compute the embeddings.
*
* Theoretically, the same CLIP model can be loaded by different frameworks /
* runtimes, but in practice each runtime has its own preferred format, and
* there are also quantization tradeoffs. So there is a specific model (a binary
* encoding of weights) tied to our current runtime that we use.
*
* To ensure that the embeddings, for the most part, can be shared, whenever
* possible we try to ensure that all the preprocessing steps, and the model
* itself, is the same across clients - web and mobile.
*/
class CLIPService {
private embeddingExtractionInProgress: AbortController | null = null;
private reRunNeeded = false;
private indexingStatus: CLIPIndexingStatus = {
pending: 0,
indexed: 0,
};
private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
private liveEmbeddingExtractionQueue: PQueue;
private onFileUploadedHandler:
| ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
| null = null;
constructor() {
this.liveEmbeddingExtractionQueue = new PQueue({
concurrency: 1,
});
}
async logout() {
if (this.embeddingExtractionInProgress) {
this.embeddingExtractionInProgress.abort();
}
if (this.onFileUploadedHandler) {
await this.removeOnFileUploadListener();
}
}
setupOnFileUploadListener = async () => {
try {
if (this.onFileUploadedHandler) {
log.info("file upload listener already setup");
return;
}
log.info("setting up file upload listener");
this.onFileUploadedHandler = (args) => {
this.runLocalFileClipExtraction(args);
};
eventBus.on(Events.FILE_UPLOADED, this.onFileUploadedHandler, this);
log.info("setup file upload listener successfully");
} catch (e) {
log.error("failed to setup clip service", e);
}
};
removeOnFileUploadListener = async () => {
try {
if (!this.onFileUploadedHandler) {
log.info("file upload listener already removed");
return;
}
log.info("removing file upload listener");
eventBus.removeListener(
Events.FILE_UPLOADED,
this.onFileUploadedHandler,
this,
);
this.onFileUploadedHandler = null;
log.info("removed file upload listener successfully");
} catch (e) {
log.error("failed to remove clip service", e);
}
};
getIndexingStatus = async () => {
if (
this.indexingStatus.pending === 0 &&
this.indexingStatus.indexed === 0
) {
this.indexingStatus = await initialIndexingStatus();
}
return this.indexingStatus;
};
/**
* Set the {@link handler} to invoke whenever our indexing status changes.
*/
setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
this.onUpdateHandler = handler;
};
scheduleImageEmbeddingExtraction = async () => {
try {
if (this.embeddingExtractionInProgress) {
log.info(
"clip embedding extraction already in progress, scheduling re-run",
);
this.reRunNeeded = true;
return;
} else {
log.info(
"clip embedding extraction not in progress, starting clip embedding extraction",
);
}
const canceller = new AbortController();
this.embeddingExtractionInProgress = canceller;
try {
await this.runClipEmbeddingExtraction(canceller);
} finally {
this.embeddingExtractionInProgress = null;
if (!canceller.signal.aborted && this.reRunNeeded) {
this.reRunNeeded = false;
log.info("re-running clip embedding extraction");
setTimeout(
() => this.scheduleImageEmbeddingExtraction(),
0,
);
}
}
} catch (e) {
if (e.message !== CustomError.REQUEST_CANCELLED) {
log.error("failed to schedule clip embedding extraction", e);
}
}
};
getTextEmbeddingIfAvailable = async (text: string) => {
return ensureElectron().computeCLIPTextEmbeddingIfAvailable(text);
};
private runClipEmbeddingExtraction = async (canceller: AbortController) => {
try {
const user = getData(LS_KEYS.USER);
if (!user) {
return;
}
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await localCLIPEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
existingEmbeddings,
);
this.updateIndexingStatus({
indexed: existingEmbeddings.length,
pending: pendingFiles.length,
});
if (pendingFiles.length === 0) {
log.info("no clip embedding extraction needed, all done");
return;
}
log.info(
`starting clip embedding extraction for ${pendingFiles.length} files`,
);
for (const file of pendingFiles) {
try {
log.info(
`extracting clip embedding for file: ${file.metadata.title} fileID: ${file.id}`,
);
if (canceller.signal.aborted) {
throw Error(CustomError.REQUEST_CANCELLED);
}
const embeddingData =
await this.extractFileClipImageEmbedding(file);
log.info(
`successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
);
await this.encryptAndUploadEmbedding(file, embeddingData);
this.onSuccessStatusUpdater();
log.info(
`successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
);
} catch (e) {
if (e?.message !== CustomError.REQUEST_CANCELLED) {
log.error(
"failed to extract clip embedding for file",
e,
);
}
if (
e?.message === CustomError.REQUEST_CANCELLED ||
e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)
) {
throw e;
}
}
}
} catch (e) {
if (e.message !== CustomError.REQUEST_CANCELLED) {
log.error("failed to extract clip embedding", e);
}
throw e;
}
};
private async runLocalFileClipExtraction(arg: {
enteFile: EnteFile;
localFile: globalThis.File;
}) {
const { enteFile, localFile } = arg;
log.info(
`clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
enteFile.id,
);
if (enteFile.metadata.fileType === FILE_TYPE.VIDEO) {
log.info(
`skipping video file for clip embedding extraction file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
return;
}
const extension = enteFile.metadata.title.split(".").pop();
if (!extension || !["jpg", "jpeg"].includes(extension)) {
log.info(
`skipping non jpg file for clip embedding extraction file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
return;
}
log.info(
`queuing up for local clip embedding extraction for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
try {
await this.liveEmbeddingExtractionQueue.add(async () => {
const embedding =
await this.extractLocalFileClipImageEmbedding(localFile);
await this.encryptAndUploadEmbedding(enteFile, embedding);
});
log.info(
`successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
} catch (e) {
log.error("Failed in ML onFileUploadedHandler", e);
}
}
private extractLocalFileClipImageEmbedding = async (localFile: File) => {
const file = await localFile
.arrayBuffer()
.then((buffer) => new Uint8Array(buffer));
return await ensureElectron().computeCLIPImageEmbedding(file);
};
private encryptAndUploadEmbedding = async (
file: EnteFile,
embeddingData: Float32Array,
) => {
if (embeddingData?.length !== 512) {
throw Error(
`invalid length embedding data length: ${embeddingData?.length}`,
);
}
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptEmbedding(embeddingData, file.key);
log.info(
`putting clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
);
await putEmbedding({
fileID: file.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "onnx-clip",
});
};
private updateIndexingStatus = (status: CLIPIndexingStatus) => {
this.indexingStatus = status;
const handler = this.onUpdateHandler;
if (handler) handler(status);
};
private extractFileClipImageEmbedding = async (file: EnteFile) => {
const thumb = await downloadManager.getThumbnail(file);
const embedding =
await ensureElectron().computeCLIPImageEmbedding(thumb);
return embedding;
};
private onSuccessStatusUpdater = () => {
this.updateIndexingStatus({
pending: this.indexingStatus.pending - 1,
indexed: this.indexingStatus.indexed + 1,
});
};
}
export const clipService = new CLIPService();
const getNonClipEmbeddingExtractedFiles = async (
files: EnteFile[],
existingEmbeddings: Embedding[],
) => {
const existingEmbeddingFileIds = new Set<number>();
existingEmbeddings.forEach((embedding) =>
existingEmbeddingFileIds.add(embedding.fileID),
);
const idSet = new Set<number>();
return files.filter((file) => {
if (idSet.has(file.id)) {
return false;
}
if (existingEmbeddingFileIds.has(file.id)) {
return false;
}
idSet.add(file.id);
return true;
});
};
export const computeClipMatchScore = async (
imageEmbedding: Float32Array,
textEmbedding: Float32Array,
) => {
if (imageEmbedding.length !== textEmbedding.length) {
throw Error("imageEmbedding and textEmbedding length mismatch");
}
let score = 0;
let imageNormalization = 0;
let textNormalization = 0;
for (let index = 0; index < imageEmbedding.length; index++) {
imageNormalization += imageEmbedding[index] * imageEmbedding[index];
textNormalization += textEmbedding[index] * textEmbedding[index];
}
for (let index = 0; index < imageEmbedding.length; index++) {
imageEmbedding[index] =
imageEmbedding[index] / Math.sqrt(imageNormalization);
textEmbedding[index] =
textEmbedding[index] / Math.sqrt(textNormalization);
}
for (let index = 0; index < imageEmbedding.length; index++) {
score += imageEmbedding[index] * textEmbedding[index];
}
return score;
};
const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
const user = getData(LS_KEYS.USER);
if (!user) throw new Error("Orphan CLIP indexing without a login");
const allEmbeddings = await localCLIPEmbeddings();
const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
allEmbeddings,
);
return {
indexed: allEmbeddings.length,
pending: pendingFiles.length,
};
};

View File

@@ -1,352 +0,0 @@
import {
getAllLocalFiles,
getLocalTrashedFiles,
} from "@/new/photos/services/files";
import type { EmbeddingModel } from "@/new/photos/services/ml/embedding";
import type { FaceIndex } from "@/new/photos/services/ml/face";
import { EnteFile } from "@/new/photos/types/file";
import { inWorker } from "@/next/env";
import log from "@/next/log";
import { apiURL } from "@/next/origins";
import { workerBridge } from "@/next/worker/worker-bridge";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error";
import HTTPService from "@ente/shared/network/HTTPService";
import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers";
import type {
Embedding,
EncryptedEmbedding,
GetEmbeddingDiffResponse,
PutEmbeddingRequest,
} from "types/embedding";
import { getLocalCollections } from "./collectionService";
type FileML = FaceIndex & {
updatedAt: number;
};
const DIFF_LIMIT = 500;
/** Local storage key suffix for embedding sync times */
const embeddingSyncTimeLSKeySuffix = "embedding_sync_time";
/** Local storage key for CLIP embeddings. */
const clipEmbeddingsLSKey = "embeddings_v2";
const FILE_EMBEDING_TABLE = "file_embeddings";
/** Return all CLIP embeddings that we have available locally. */
export const localCLIPEmbeddings = async () =>
(await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip");
const storedCLIPEmbeddings = async () => {
const embeddings: Array<Embedding> =
await localForage.getItem<Embedding[]>(clipEmbeddingsLSKey);
if (!embeddings) {
// Migrate
await localForage.removeItem("embeddings");
await localForage.removeItem("embedding_sync_time");
await localForage.setItem(clipEmbeddingsLSKey, []);
return [];
}
return embeddings;
};
export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
const embeddings: Array<FileML> =
await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
if (!embeddings) {
return [];
}
return embeddings;
};
const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
return (
(await localForage.getItem<number>(
`${model}-${embeddingSyncTimeLSKeySuffix}`,
)) ?? 0
);
};
const setModelEmbeddingSyncTime = async (
model: EmbeddingModel,
time: number,
) => {
await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time);
};
/**
* Fetch new CLIP embeddings from the server and save them locally. Also prune
* local embeddings for any files that no longer exist locally.
*/
export const syncCLIPEmbeddings = async () => {
const model: EmbeddingModel = "onnx-clip";
try {
let allEmbeddings = await storedCLIPEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
const fileIdToKeyMap = new Map<number, string>();
const allLocalFiles = [...localFiles, ...localTrashFiles];
allLocalFiles.forEach((file) => {
fileIdToKeyMap.set(file.id, file.key);
});
await cleanupDeletedEmbeddings(
allLocalFiles,
allEmbeddings,
clipEmbeddingsLSKey,
);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
// Note: in rare cases we might get a diff entry for an embedding
// corresponding to a file which has been deleted (but whose
// embedding is enqueued for deletion). Client should expect such a
// scenario (all it has to do is just ignore them).
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const {
encryptedEmbedding,
decryptionHeader,
...rest
} = embedding;
const worker = await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(embedding.fileID);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptEmbedding(
encryptedEmbedding,
decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...rest,
embedding: decryptedData,
} as Embedding;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
modelLastSinceTime = response.diff.reduce(
(max, { updatedAt }) => Math.max(max, updatedAt),
modelLastSinceTime,
);
await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length > 0);
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
export const syncFaceEmbeddings = async () => {
const model: EmbeddingModel = "file-ml-clip-face";
try {
let allEmbeddings: FileML[] = await getFileMLEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
const fileIdToKeyMap = new Map<number, string>();
const allLocalFiles = [...localFiles, ...localTrashFiles];
allLocalFiles.forEach((file) => {
fileIdToKeyMap.set(file.id, file.key);
});
await cleanupDeletedEmbeddings(
allLocalFiles,
allEmbeddings,
FILE_EMBEDING_TABLE,
);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const worker = await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(embedding.fileID);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptMetadata(
embedding.encryptedEmbedding,
embedding.decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...decryptedData,
updatedAt: embedding.updatedAt,
} as unknown as FileML;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionFileEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
modelLastSinceTime = response.diff.reduce(
(max, { updatedAt }) => Math.max(max, updatedAt),
modelLastSinceTime,
);
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length > 0);
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
const latestVersionEntities = new Map<number, Embedding>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
const latestVersionEntities = new Map<number, FileML>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
export const getEmbeddingsDiff = async (
sinceTime: number,
model: EmbeddingModel,
): Promise<GetEmbeddingDiffResponse> => {
try {
const token = getToken();
if (!token) {
return;
}
const response = await HTTPService.get(
await apiURL("/embeddings/diff"),
{
sinceTime,
limit: DIFF_LIMIT,
model,
},
{
"X-Auth-Token": token,
},
);
return await response.data;
} catch (e) {
log.error("get embeddings diff failed", e);
throw e;
}
};
export const putEmbedding = async (
putEmbeddingReq: PutEmbeddingRequest,
): Promise<EncryptedEmbedding> => {
try {
const token = inWorker()
? await workerBridge.getAuthToken()
: getToken();
if (!token) {
log.info("putEmbedding failed: token not found");
throw Error(CustomError.TOKEN_MISSING);
}
const resp = await HTTPService.put(
await apiURL("/embeddings"),
putEmbeddingReq,
null,
{
"X-Auth-Token": token,
},
);
return resp.data;
} catch (e) {
log.error("put embedding failed", e);
throw e;
}
};
export const cleanupDeletedEmbeddings = async (
allLocalFiles: EnteFile[],
allLocalEmbeddings: Embedding[] | FileML[],
tableName: string,
) => {
const activeFileIds = new Set<number>();
allLocalFiles.forEach((file) => {
activeFileIds.add(file.id);
});
const remainingEmbeddings = allLocalEmbeddings.filter((embedding) =>
activeFileIds.has(embedding.fileID),
);
if (allLocalEmbeddings.length !== remainingEmbeddings.length) {
log.info(
`cleanupDeletedEmbeddings embeddingsCount: ${allLocalEmbeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`,
);
await localForage.setItem(tableName, remainingEmbeddings);
}
};

View File

@@ -3,7 +3,6 @@ import DownloadManager from "@/new/photos/services/download";
import { clearFeatureFlagSessionState } from "@/new/photos/services/feature-flags";
import { logoutML, terminateMLWorker } from "@/new/photos/services/ml";
import log from "@/next/log";
import { clipService } from "services/clip-service";
import exportService from "./export";
/**
@@ -46,12 +45,6 @@ export const photosLogout = async () => {
ignoreError("download", e);
}
try {
await clipService.logout();
} catch (e) {
ignoreError("CLIP", e);
}
// - Desktop
const electron = globalThis.electron;

View File

@@ -19,15 +19,11 @@ import {
import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
import { getUniqueFiles } from "utils/file";
import { getFormattedDate } from "utils/search";
import { clipService, computeClipMatchScore } from "./clip-service";
import { localCLIPEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService";
const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
const CLIP_SCORE_THRESHOLD = 0.23;
export const getDefaultOptions = async () => {
return [
// TODO-ML(MR): Skip this for now if indexing is disabled (eventually
@@ -373,30 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
}
const searchClip = async (
searchPhrase: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_searchPhrase: string,
): Promise<ClipSearchScores | undefined> => {
const textEmbedding =
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
if (!textEmbedding) return undefined;
const imageEmbeddings = await localCLIPEmbeddings();
const clipSearchResult = new Map<number, number>(
(
await Promise.all(
imageEmbeddings.map(
async (imageEmbedding): Promise<[number, number]> => [
imageEmbedding.fileID,
await computeClipMatchScore(
imageEmbedding.embedding,
textEmbedding,
),
],
),
)
).filter(([, score]) => score >= CLIP_SCORE_THRESHOLD),
);
return clipSearchResult;
// TODO-ML: clip-test
return undefined;
// const matches = await clipMatches(searchPhrase, ensureElectron());
// log.debug(() => ({ t: "clip-scores", matches }));
// return matches;
};
function convertSuggestionToSearchQuery(option: Suggestion): Search {

View File

@@ -1,8 +1,6 @@
import { fetchAndSaveFeatureFlagsIfNeeded } from "@/new/photos/services/feature-flags";
import { triggerMLSync } from "@/new/photos/services/ml";
import { isDesktop } from "@/next/app";
import { clipService } from "services/clip-service";
import { syncCLIPEmbeddings } from "services/embeddingService";
import { syncEntities } from "services/entityService";
import { syncMapEnabled } from "services/userService";
@@ -20,8 +18,6 @@ export const sync = async () => {
await syncMapEnabled();
fetchAndSaveFeatureFlagsIfNeeded();
if (isDesktop) {
await syncCLIPEmbeddings();
triggerMLSync();
void clipService.scheduleImageEmbeddingExtraction();
}
};

View File

@@ -19,7 +19,6 @@ import { wait } from "@/utils/promise";
import { getDedicatedCryptoWorker } from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import { CustomError } from "@ente/shared/error";
import { Events, eventBus } from "@ente/shared/events";
import { Canceler } from "axios";
import type { Remote } from "comlink";
import isElectron from "is-electron";
@@ -617,27 +616,6 @@ class UploadManager {
const uploadItem =
uploadableItem.uploadItem ??
uploadableItem.livePhotoAssets.image;
try {
let file: File | undefined;
if (uploadItem) {
if (uploadItem instanceof File) {
file = uploadItem;
} else if (
typeof uploadItem == "string" ||
Array.isArray(uploadItem)
) {
// path from desktop, no file object
} else {
file = uploadItem.file;
}
}
eventBus.emit(Events.FILE_UPLOADED, {
enteFile: decryptedFile,
localFile: file,
});
} catch (e) {
log.warn("Ignoring error in fileUploaded handlers", e);
}
if (
uploadItem &&
(uploadResult == UPLOAD_RESULT.UPLOADED ||

View File

@@ -1,29 +0,0 @@
import type { EmbeddingModel } from "@/new/photos/services/ml/embedding";
export interface EncryptedEmbedding {
fileID: number;
/** @see {@link EmbeddingModel} */
model: string;
encryptedEmbedding: string;
decryptionHeader: string;
updatedAt: number;
}
export interface Embedding
extends Omit<
EncryptedEmbedding,
"encryptedEmbedding" | "decryptionHeader"
> {
embedding?: Float32Array;
}
export interface GetEmbeddingDiffResponse {
diff: EncryptedEmbedding[];
}
export interface PutEmbeddingRequest {
fileID: number;
model: EmbeddingModel;
encryptedEmbedding: string;
decryptionHeader: string;
}

View File

@@ -17,12 +17,21 @@
* implies "ESNext.*"), "DOM.Iterable" (the ability to iterate over DOM
* elements) is not a subset of "DOM" and needs to be listed out
* explicitly.
*
* Note that we don't need to specify the `target` compilerOption, since
* tsc isn't actually generating (emitting) the JavaScript.
*/
"lib": ["ESnext", "DOM", "DOM.Iterable", "WebWorker"],
/**
* We don't strictly need to specify the `target` compilerOption,
* since tsc isn't actually generating (emitting) the JavaScript.
* However, without a target TypeScript becomes very conservative, and
* complains about things like
*
* > Type 'Float32Array' can only be iterated through when using the
* > '--downlevelIteration' flag or with a '--target' of 'es2015' or
* > higher.
*/
"target": "ES2015",
/*
* The module system to assume the generated JavaScript will use.
*

View File

@@ -0,0 +1,208 @@
import type { EnteFile } from "@/new/photos/types/file";
import type { Electron } from "@/next/types/ipc";
import type { ImageBitmapAndData } from "./bitmap";
import { clipIndexes } from "./db";
import { pixelRGBBicubic } from "./image";
import { cosineSimilarity, norm } from "./math";
import type { MLWorkerElectron } from "./worker-electron";
/**
* The version of the CLIP indexing pipeline implemented by the current client.
*/
export const clipIndexingVersion = 1;
/**
* The CLIP embedding for a file (and some metadata).
*
* See {@link FaceIndex} for a similar structure with more comprehensive
* documentation.
*
* ---
*
* [Note: Natural language search using CLIP]
*
* CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
* (image, text) pairs. It can be thought of as two separate (but jointly
* trained) encoders - one for images, and one for text - that both map to the
* same embedding space.
*
* We use this for natural language search within the app:
*
* 1. Pre-compute an embedding for each image.
*
* 2. When the user searches, compute an embedding for the search term.
*
* 3. Use cosine similarity to find the image (embedding) closest to
* the text (embedding).
*
* More details are in our [blog
* post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
* initial launch of this feature using the GGML runtime.
*
* Since the initial launch, we've switched over to another runtime,
* [ONNX](https://onnxruntime.ai) and have made other implementation changes,
* but the overall gist remains the same.
*
* Note that we don't train the neural network - we only use one of the publicly
* available pre-trained neural networks for inference. These neural networks
* are wholly defined by their connectivity and weights. ONNX, our ML runtime,
* loads these weights and instantiates a running network that we can use to
* compute the embeddings.
*
* Theoretically, the same CLIP model can be loaded by different frameworks /
* runtimes, but in practice each runtime has its own preferred format, and
* there are also quantization tradeoffs. So there is a specific model (a binary
* encoding of weights) tied to our current runtime that we use.
*
* To ensure that the embeddings, for the most part, can be shared, whenever
* possible we try to ensure that all the preprocessing steps, and the model
* itself, are the same across clients - web and mobile.
*/
export interface CLIPIndex {
/** The ID of the {@link EnteFile} whose index this is. */
fileID: number;
/** An integral version number of the indexing algorithm / pipeline. */
version: number;
/** The UA for the client which generated this embedding. */
client: string;
/**
* The CLIP embedding itself.
*
* This is an array of 512 floating point values that represent the
* embedding of the image in the same space where we'll embed the text so
* that both of them can be compared using a cosine distance.
*/
embedding: number[];
}
/**
* Compute the CLIP embedding of a given file.
*
* This function is the entry point to the CLIP indexing pipeline. The file goes
* through various stages:
*
* 1. Downloading the original if needed.
* 2. Convert (if needed) and pre-process.
* 3. Compute embeddings using ONNX/CLIP.
*
* Once all of this is done, it returns the CLIP embedding (wrapped as a {@link CLIPIndex} so
* that it can be saved locally and also uploaded to the user's remote storage
* for use on their other devices).
*
* @param enteFile The {@link EnteFile} to index.
*
* @param uploadItem If we're called during the upload process, then this will
* be set to the {@link UploadItem} that was uploaded. This way, we can directly
* use the on-disk file instead of needing to download the original from remote.
*
* @param electron The {@link MLWorkerElectron} instance that allows us to call
* our Node.js layer for various functionality.
*
* @param userAgent The UA of the client that is doing the indexing (us).
*/
export const indexCLIP = async (
enteFile: EnteFile,
image: ImageBitmapAndData,
electron: MLWorkerElectron,
userAgent: string,
): Promise<CLIPIndex> => {
const { data: imageData } = image;
const fileID = enteFile.id;
return {
fileID,
version: clipIndexingVersion,
client: userAgent,
embedding: await computeEmbedding(imageData, electron),
};
};
const computeEmbedding = async (
imageData: ImageData,
electron: MLWorkerElectron,
): Promise<number[]> => {
const clipInput = convertToCLIPInput(imageData);
return normalized(await electron.computeCLIPImageEmbedding(clipInput));
};
/**
* Convert {@link imageData} into the format that the CLIP model expects.
*/
const convertToCLIPInput = (imageData: ImageData) => {
const requiredWidth = 224;
const requiredHeight = 224;
const mean = [0.48145466, 0.4578275, 0.40821073] as const;
const std = [0.26862954, 0.26130258, 0.27577711] as const;
const { width, height, data: pixelData } = imageData;
// Maintain aspect ratio.
const scale = Math.max(requiredWidth / width, requiredHeight / height);
const scaledWidth = Math.round(width * scale);
const scaledHeight = Math.round(height * scale);
const widthOffset = Math.max(0, scaledWidth - requiredWidth) / 2;
const heightOffset = Math.max(0, scaledHeight - requiredHeight) / 2;
const clipInput = new Float32Array(3 * requiredWidth * requiredHeight);
// Populate the Float32Array with normalized pixel values.
let pi = 0;
const cOffsetG = requiredHeight * requiredWidth; // ChannelOffsetGreen
const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
const { r, g, b } = pixelRGBBicubic(
w / scale,
h / scale,
pixelData,
width,
height,
);
clipInput[pi] = (r / 255.0 - mean[0]) / std[0];
clipInput[pi + cOffsetG] = (g / 255.0 - mean[1]) / std[1];
clipInput[pi + cOffsetB] = (b / 255.0 - mean[2]) / std[2];
pi++;
}
}
return clipInput;
};
const normalized = (embedding: Float32Array) => {
const nums = Array.from(embedding);
const n = norm(nums);
return nums.map((v) => v / n);
};
/**
* Use CLIP to perform a natural language search over image embeddings.
*
* @param searchPhrase The text entered by the user in the search box.
*
* @param electron The {@link Electron} instance to use to communicate with the
* native code running in our desktop app (the embedding happens in the native
* layer).
*
* It returns a list of files that should be shown in the search results. The
* actual return type is a map from fileIDs to the scores they got (higher is
* better). This map will only contain entries whose score was above our
* minimum threshold.
*
* The result can also be `undefined`, which indicates that the download for the
* ML model is still in progress (trying again later should succeed).
*/
export const clipMatches = async (
searchPhrase: string,
electron: Electron,
): Promise<Map<number, number> | undefined> => {
const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase);
if (!t) return undefined;
const textEmbedding = normalized(t);
const items = (await clipIndexes()).map(
({ fileID, embedding }) =>
[fileID, cosineSimilarity(embedding, textEmbedding)] as const,
);
return new Map(items.filter(([, score]) => score >= 0.23));
};
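The `norm` and `cosineSimilarity` helpers imported from `./math` are not part of this diff; as a rough sketch of what they presumably compute (signatures and implementation here are assumptions):

```ts
// L2 norm (magnitude) of a vector.
export const norm = (v: number[]): number =>
    Math.sqrt(v.reduce((sum, x) => sum + x * x, 0));

// Cosine similarity: dot product over the product of magnitudes. For the
// unit-length embeddings stored by this pipeline it is effectively just the
// dot product.
export const cosineSimilarity = (a: number[], b: number[]): number => {
    let dot = 0;
    for (let i = 0; i < a.length; i++) dot += (a[i] ?? 0) * (b[i] ?? 0);
    return dot / (norm(a) * norm(b));
};
```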

View File

@@ -3,7 +3,7 @@ import { ensure } from "@/utils/ensure";
import type { EnteFile } from "../../types/file";
import { renderableImageBitmap } from "./bitmap";
import { type Box, type FaceIndex } from "./face";
import { clamp } from "./image";
import { clamp } from "./math";
/**
* Regenerate and locally save face crops for faces in the given file.

View File

@@ -1,37 +1,48 @@
import log from "@/next/log";
import localForage from "@ente/shared/storage/localForage";
import { deleteDB, openDB, type DBSchema } from "idb";
import type { CLIPIndex } from "./clip";
import type { EmbeddingModel } from "./embedding";
import type { FaceIndex } from "./face";
/**
* Face DB schema.
* ML DB schema.
*
* There "face" database is made of two object stores:
* The "ML" database is made of three object stores:
*
* - "file-status": Contains {@link FileStatus} objects, one for each
* {@link EnteFile} that the ML subsystem knows about. Periodically (and when
* required), this is synced with the list of files that the current client
* knows about locally.
*
* - "face-index": Contains {@link FaceIndex} objects, either indexed locally or
* fetched from remote storage.
*
* - "file-status": Contains {@link FileStatus} objects, one for each
* {@link EnteFile} that the current client knows about.
* - "clip-index": Contains {@link CLIPIndex} objects, either indexed locally or
* fetched from remote storage.
*
* Both the stores are keyed by {@link fileID}, and are expected to contain the
* exact same set of {@link fileID}s. The face-index can be thought of as the
* "original" indexing result, whilst file-status bookkeeps information about
* the indexing process (whether or not a file needs indexing, or if there were
* errors doing so).
* All the stores are keyed by {@link fileID}. The "file-status" contains
* book-keeping about the indexing process (whether or not a file needs
* indexing, or if there were errors doing so), while the other stores contain
* the actual indexing results.
*
* In tandem, these serve as the underlying storage for the functions exposed by
* this file.
*/
interface FaceDBSchema extends DBSchema {
"face-index": {
key: number;
value: FaceIndex;
};
interface MLDBSchema extends DBSchema {
"file-status": {
key: number;
value: FileStatus;
indexes: { status: FileStatus["status"] };
};
"face-index": {
key: number;
value: FaceIndex;
};
"clip-index": {
key: number;
value: CLIPIndex;
};
}
interface FileStatus {
@@ -44,94 +55,128 @@ interface FileStatus {
* to be indexed.
*
* - "indexed" - We have a corresponding entry for this file in the
* "face-index" object (either indexed locally or fetched from remote).
* "face-index" _and_ "clip-index" object stores (either indexed locally
* or fetched from remote).
*
* - "failed" - Indexing was attempted but failed.
*
* We also have a (IndexedDB) "index" on this field to allow us to
* efficiently select or count {@link fileIDs} that fall into various
* buckets.
* There can arise situations in which a file has one, but not all, indexes.
* e.g. it may have a "face-index" but its "clip-index" might not have been
* computed yet (or might have failed). In such cases, the status of the file will be
* "indexable": it transitions to "indexed" only after all indexes have been
* computed or fetched.
*
* If you haven't heard the word "index" to the point of zoning out, we also
* have a (IndexedDB) "index" on the status field to allow us to efficiently
* select or count {@link fileIDs} that fall into various buckets.
*/
status: "indexable" | "indexed" | "failed";
/**
* A list of embeddings that we still need to compute for the file.
*
* This is guaranteed to be empty if status is "indexed", and will have at
* least one entry otherwise.
*/
pending: EmbeddingModel[];
/**
* The number of times attempts to index this file failed.
*
* It counts failure across all index types.
*
* This is guaranteed to be `0` for files with status "indexed".
*/
failureCount: number;
}
/**
* A lazily-created, cached promise for face DB.
* A lazily-created, cached promise for ML DB.
*
* See: [Note: Caching IDB instances in separate execution contexts].
*/
let _faceDB: ReturnType<typeof openFaceDB> | undefined;
let _mlDB: ReturnType<typeof openMLDB> | undefined;
const openFaceDB = async () => {
const openMLDB = async () => {
deleteLegacyDB();
const db = await openDB<FaceDBSchema>("face", 1, {
// TODO-ML: "face" => "ml", v2 => v1
const db = await openDB<MLDBSchema>("face", 2, {
upgrade(db, oldVersion, newVersion) {
log.info(`Upgrading face DB ${oldVersion} => ${newVersion}`);
log.info(`Upgrading ML DB ${oldVersion} => ${newVersion}`);
if (oldVersion < 1) {
db.createObjectStore("face-index", { keyPath: "fileID" });
db.createObjectStore("file-status", {
keyPath: "fileID",
}).createIndex("status", "status");
db.createObjectStore("face-index", { keyPath: "fileID" });
}
if (oldVersion < 2) {
db.createObjectStore("clip-index", { keyPath: "fileID" });
}
},
blocking() {
log.info(
"Another client is attempting to open a new version of face DB",
"Another client is attempting to open a new version of ML DB",
);
db.close();
_faceDB = undefined;
_mlDB = undefined;
},
blocked() {
log.warn(
"Waiting for an existing client to close their connection so that we can update the face DB version",
"Waiting for an existing client to close their connection so that we can update the ML DB version",
);
},
terminated() {
log.warn("Our connection to face DB was unexpectedly terminated");
_faceDB = undefined;
log.warn("Our connection to ML DB was unexpectedly terminated");
_mlDB = undefined;
},
});
return db;
};
const deleteLegacyDB = () => {
// Delete the legacy face DB.
// Delete the legacy face DB v1.
//
// This code was added June 2024 (v1.7.1-rc) and can be removed once clients
// have migrated over.
void deleteDB("mldata");
// Delete the legacy CLIP (mostly) related keys from LocalForage.
//
// This code was added July 2024 (v1.7.2-rc) and can be removed once
// sufficient clients have migrated over (tag: Migration).
void Promise.all([
localForage.removeItem("embeddings"),
localForage.removeItem("embedding_sync_time"),
localForage.removeItem("embeddings_v2"),
localForage.removeItem("file_embeddings"),
localForage.removeItem("onnx-clip-embedding_sync_time"),
localForage.removeItem("file-ml-clip-face-embedding_sync_time"),
]);
};
/**
* @returns a lazily created, cached connection to the face DB.
* @returns a lazily created, cached connection to the ML DB.
*/
const faceDB = () => (_faceDB ??= openFaceDB());
const mlDB = () => (_mlDB ??= openMLDB());
/**
* Clear any data stored in the face DB.
* Clear any data stored in the ML DB.
*
* This is meant to be called during logout in the main thread.
* This is meant to be called during logout on the main thread.
*/
export const clearFaceDB = async () => {
export const clearMLDB = async () => {
deleteLegacyDB();
try {
if (_faceDB) (await _faceDB).close();
if (_mlDB) (await _mlDB).close();
} catch (e) {
log.warn("Ignoring error when trying to close face DB", e);
log.warn("Ignoring error when trying to close ML DB", e);
}
_faceDB = undefined;
_mlDB = undefined;
return deleteDB("face", {
blocked() {
log.warn(
"Waiting for an existing client to close their connection so that we can delete the face DB",
"Waiting for an existing client to close their connection so that we can delete the ML DB",
);
},
});
@@ -141,23 +186,79 @@ export const clearFaceDB = async () => {
* Save the given {@link faceIndex} locally.
*
* @param faceIndex A {@link FaceIndex} representing the faces that we detected
* (and their corresponding embeddings) in some file.
* (and their corresponding embeddings) in a particular file.
*
* This function adds a new entry, overwriting any existing ones (No merging is
* performed, the existing entry is unconditionally overwritten).
* This function adds a new entry for the face index, overwriting any existing
* ones (No merging is performed, the existing entry is unconditionally
* overwritten). The file status is also updated to remove face from the pending
* embeddings. If there are no other pending embeddings, the status changes to
* "indexed".
*/
export const saveFaceIndex = async (faceIndex: FaceIndex) => {
const db = await faceDB();
const tx = db.transaction(["face-index", "file-status"], "readwrite");
const indexStore = tx.objectStore("face-index");
const { fileID } = faceIndex;
const db = await mlDB();
const tx = db.transaction(["file-status", "face-index"], "readwrite");
const statusStore = tx.objectStore("file-status");
const indexStore = tx.objectStore("face-index");
const fileStatus =
(await statusStore.get(IDBKeyRange.only(fileID))) ??
newFileStatus(fileID);
fileStatus.pending = fileStatus.pending.filter(
(v) => v != "file-ml-clip-face",
);
if (fileStatus.pending.length == 0) fileStatus.status = "indexed";
await Promise.all([
statusStore.put(fileStatus),
indexStore.put(faceIndex),
statusStore.put({
fileID: faceIndex.fileID,
status: "indexed",
failureCount: 0,
}),
tx.done,
]);
};
/**
* Return a new object suitable for use as the initial value of the entry for a
* file in the file status store.
*/
const newFileStatus = (fileID: number): FileStatus => ({
fileID,
status: "indexable",
// TODO-ML: clip-test
// pending: ["file-ml-clip-face", "onnx-clip"],
pending: ["file-ml-clip-face"],
failureCount: 0,
});
/**
* Save the given {@link clipIndex} locally.
*
* @param clipIndex A {@link CLIPIndex} containing the CLIP embedding for a
* particular file.
*
* This function adds a new entry for the CLIP index, overwriting any existing
* ones (No merging is performed, the existing entry is unconditionally
* overwritten). The file status is also updated to remove CLIP from the pending
* embeddings. If there are no other pending embeddings, the status changes to
* "indexed".
*/
export const saveCLIPIndex = async (clipIndex: CLIPIndex) => {
const { fileID } = clipIndex;
const db = await mlDB();
const tx = db.transaction(["file-status", "clip-index"], "readwrite");
const statusStore = tx.objectStore("file-status");
const indexStore = tx.objectStore("clip-index");
const fileStatus =
(await statusStore.get(IDBKeyRange.only(fileID))) ??
newFileStatus(fileID);
fileStatus.pending = fileStatus.pending.filter((v) => v != "onnx-clip");
if (fileStatus.pending.length == 0) fileStatus.status = "indexed";
await Promise.all([
statusStore.put(fileStatus),
indexStore.put(clipIndex),
tx.done,
]);
};
@@ -166,13 +267,21 @@ export const saveFaceIndex = async (faceIndex: FaceIndex) => {
* Return the {@link FaceIndex}, if any, for {@link fileID}.
*/
export const faceIndex = async (fileID: number) => {
const db = await faceDB();
const db = await mlDB();
return db.get("face-index", fileID);
};
/**
* Record the existence of a file so that entities in the face indexing universe
* know about it (e.g. can index it if it is new and it needs indexing).
* Return all CLIP indexes present locally.
*/
export const clipIndexes = async () => {
const db = await mlDB();
return await db.getAll("clip-index");
};
/**
* Record the existence of a file so that entities in the ML universe know about
* it (e.g. can index it if it is new and it needs indexing).
*
* @param fileID The ID of an {@link EnteFile}.
*
@@ -182,48 +291,45 @@ export const faceIndex = async (fileID: number) => {
* unperturbed.
*/
export const addFileEntry = async (fileID: number) => {
const db = await faceDB();
const db = await mlDB();
const tx = db.transaction("file-status", "readwrite");
if ((await tx.store.getKey(fileID)) === undefined) {
await tx.store.put({
fileID,
status: "indexable",
failureCount: 0,
});
}
if ((await tx.store.getKey(fileID)) === undefined)
await tx.store.put(newFileStatus(fileID));
return tx.done;
};
/**
* Update entries in the face DB to align with the state of local files outside
* face DB.
* Update entries in ML DB to align with the state of local files outside ML DB.
*
* @param localFileIDs IDs of all the files that the client is aware of filtered
* to only keep the files that the user owns and the formats that can be indexed
* by our current face indexing pipeline.
* @param localFileIDs IDs of all the files that the client is aware of,
* filtered to only keep the files that the user owns and the formats that can
* be indexed by our current indexing pipelines.
*
* @param localTrashFilesIDs IDs of all the files in trash.
*
* This function then updates the state of file entries in face DB to the be in
* sync with these provided local file IDS.
* This function then updates the state of file entries in ML DB to be in sync
* with these provided local file IDs.
*
* - Files that are present locally but are not yet in face DB get a fresh entry
* - Files that are present locally but are not yet in ML DB get a fresh entry
* in face DB (and are marked as indexable).
*
* - Files that are not present locally (nor are in trash) but still exist in
* face DB are removed from face DB (including their face index, if any).
* - Files that are not present locally (nor are in trash) but still exist in ML
* DB are removed from ML DB (including any indexes).
*
* - Files that are not present locally but are in the trash are retained in
* face DB if their status is "indexed" (otherwise they too are removed). This
* is prevent churn (re-indexing) if the user moves some files to trash but
* then later restores them before they get permanently deleted.
* - Files that are not present locally but are in the trash are retained in ML
* DB if their status is "indexed"; otherwise they too are removed. This
* special case is to prevent churn (re-indexing) if the user moves some files
* to trash but then later restores them before they get permanently deleted.
*/
export const updateAssumingLocalFiles = async (
localFileIDs: number[],
localTrashFilesIDs: number[],
) => {
const db = await faceDB();
const tx = db.transaction(["face-index", "file-status"], "readwrite");
const db = await mlDB();
const tx = db.transaction(
["file-status", "face-index", "clip-index"],
"readwrite",
);
const fdbFileIDs = await tx.objectStore("file-status").getAllKeys();
const fdbIndexedFileIDs = await tx
.objectStore("file-status")
@@ -250,16 +356,13 @@ export const updateAssumingLocalFiles = async (
await Promise.all(
[
newFileIDs.map((id) =>
tx.objectStore("file-status").put({
fileID: id,
status: "indexable",
failureCount: 0,
}),
tx.objectStore("file-status").put(newFileStatus(id)),
),
removedFileIDs.map((id) =>
tx.objectStore("file-status").delete(id),
),
removedFileIDs.map((id) => tx.objectStore("face-index").delete(id)),
removedFileIDs.map((id) => tx.objectStore("clip-index").delete(id)),
tx.done,
].flat(),
);
@@ -268,11 +371,12 @@ export const updateAssumingLocalFiles = async (
/**
* Return the count of files that can be, and that have been, indexed.
*
* These counts are mutually exclusive. The total number of files that fall
* within the purview of the indexer is thus indexable + indexed.
* These counts are mutually exclusive. Thus the total number of files that fall
* within the purview of the indexer will be indexable + indexed (ignoring the
* "failed" ones).
*/
export const indexableAndIndexedCounts = async () => {
const db = await faceDB();
const db = await mlDB();
const tx = db.transaction("file-status", "readwrite");
const indexableCount = await tx.store
.index("status")
@@ -287,14 +391,14 @@ export const indexableAndIndexedCounts = async () => {
* Return a list of fileIDs that need to be indexed.
*
* This list is from the universe of the file IDs that the face DB knows about
* (can use {@link addFileEntry} to inform it about new files). From this
* (we can use {@link addFileEntry} to inform it about new files). From this
* universe, we filter out the fileIDs of files that have already been indexed,
* or that should be ignored.
*
* @param count Limit the result to up to {@link count} items.
*/
export const indexableFileIDs = async (count?: number) => {
const db = await faceDB();
const db = await mlDB();
const tx = db.transaction("file-status", "readonly");
return tx.store
.index("status")
@@ -309,15 +413,15 @@ export const indexableFileIDs = async (count?: number) => {
* If an entry does not exist yet for the given file, then a new one is created
* and its failure count is set to 1. Otherwise the failure count of the
* existing entry is incremented.
*
* This count is across all different types of indexing (face, CLIP) that happen
* on the file.
*/
export const markIndexingFailed = async (fileID: number) => {
const db = await faceDB();
const db = await mlDB();
const tx = db.transaction("file-status", "readwrite");
const failureCount = ((await tx.store.get(fileID))?.failureCount ?? 0) + 1;
await tx.store.put({
fileID,
status: "failed",
failureCount,
});
return tx.done;
const fileStatus = (await tx.store.get(fileID)) ?? newFileStatus(fileID);
fileStatus.status = "failed";
fileStatus.failureCount = fileStatus.failureCount + 1;
await Promise.all([tx.store.put(fileStatus), tx.done]);
};

View File

@@ -12,8 +12,9 @@ import { getKV, setKV } from "@/next/kv";
import log from "@/next/log";
import { apiURL } from "@/next/origins";
import { z } from "zod";
import { saveFaceIndex } from "./db";
import { type FaceIndex, faceIndexingVersion } from "./face";
import { clipIndexingVersion, type CLIPIndex } from "./clip";
import { saveCLIPIndex, saveFaceIndex } from "./db";
import { faceIndexingVersion, type FaceIndex } from "./face";
/**
* The embeddings that we (the current client) knows how to handle.
@@ -48,7 +49,7 @@ import { type FaceIndex, faceIndexingVersion } from "./face";
* "model" (i.e "type") field to create a new universe of embeddings.
*/
export type EmbeddingModel =
| "onnx-clip" /* CLIP (text) embeddings */
| "onnx-clip" /* CLIP embeddings */
| "file-ml-clip-face" /* Face embeddings */;
const RemoteEmbedding = z.object({
@@ -397,3 +398,61 @@ const FaceIndex = z
*/
export const putFaceIndex = async (enteFile: EnteFile, faceIndex: FaceIndex) =>
putEmbedding(enteFile, "file-ml-clip-face", JSON.stringify(faceIndex));
// MARK: - CLIP
/**
* Fetch new or updated CLIP embeddings from remote and save them locally.
*
* See {@link pullFaceEmbeddings} for a sibling function with more comprehensive
* documentation.
*
* @returns true if at least one embedding was pulled, false otherwise.
*/
export const pullCLIPEmbeddings = () =>
pullEmbeddings("onnx-clip", (jsonString: string) =>
// eslint-disable-next-line @typescript-eslint/prefer-ts-expect-error, @typescript-eslint/ban-ts-comment
// @ts-ignore TODO: There is no error here, but this file is imported by
// one of our packages that doesn't have strict mode enabled yet,
// causing a spurious error to be emitted in that context.
saveCLIPIndexIfNewer(CLIPIndex.parse(JSON.parse(jsonString))),
);
/**
* Save the given {@link clipIndex} locally if it is newer than the one we have.
*
* This is a variant of {@link saveCLIPIndex} that performs version checking as
* described in [Note: Handling versioning of embeddings].
*/
const saveCLIPIndexIfNewer = async (index: CLIPIndex) => {
const version = index.version;
if (version < clipIndexingVersion) {
log.info(
`Ignoring remote CLIP index with version ${version} older than what our indexer can produce (${clipIndexingVersion})`,
);
return;
}
return saveCLIPIndex(index);
};
/**
* Zod schemas for the {@link CLIPIndex} types.
*
* See: [Note: Duplicated between Zod schemas and TS type]
*/
const CLIPIndex = z
.object({
fileID: z.number(),
version: z.number(),
client: z.string(),
embedding: z.array(z.number()),
})
// Retain fields we might not (currently) understand.
.passthrough();
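A quick, hypothetical illustration of what `.passthrough()` buys here: fields that this client does not (yet) understand survive a parse and re-serialize round trip, so writing the index back does not strip data added by newer clients. The payload below is made up for illustration.

```ts
// Hypothetical remote payload containing a field this client doesn't know about.
const remoteJSON =
    '{"fileID": 42, "version": 1, "client": "web", "embedding": [0.1, 0.2], "newField": "kept"}';

const parsed = CLIPIndex.parse(JSON.parse(remoteJSON));
// Because the schema uses .passthrough(), "newField" is still present.
console.log(JSON.stringify(parsed)); // => {..., "newField": "kept"}
```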
/**
* Save the CLIP index for the given {@link enteFile} on remote so that other
* clients can directly pull it instead of needing to reindex.
*/
export const putCLIPIndex = async (enteFile: EnteFile, clipIndex: CLIPIndex) =>
putEmbedding(enteFile, "onnx-clip", JSON.stringify(clipIndex));

View File

@@ -21,11 +21,11 @@ import {
import type { ImageBitmapAndData } from "./bitmap";
import { saveFaceCrops } from "./crop";
import {
clamp,
grayscaleIntMatrixFromNormalized2List,
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
import { clamp } from "./math";
import type { MLWorkerElectron } from "./worker-electron";
/**
@@ -196,24 +196,26 @@ export interface Box {
/**
* Index faces in the given file.
*
* This function is the entry point to the indexing pipeline. The file goes
* This function is the entry point to the face indexing pipeline. The file goes
* through various stages:
*
* 1. Download the original if needed.
* 2. Detect faces using ONNX/YOLO.
* 3. Align the face rectangles and compute blur.
* 4. Compute embeddings for the detected face (crops).
* 4. Compute embeddings using ONNX/MFNT for the detected face (crop).
*
* Once all of it is done, it returns the face rectangles and embeddings so that
* they can be saved locally for offline use, and encrypts and uploads them to
* the user's remote storage so that their other devices can download them
* instead of needing to reindex.
* they can be saved locally (for offline use), and also uploaded to the user's
* remote storage so that their other devices can download them instead of
* needing to reindex.
*
* As an optimization, it also saves the face crops of the detected faces to the
* local cache (they can be regenerated independently too by using
* {@link regenerateFaceCrops}).
*
* @param enteFile The {@link EnteFile} to index.
*
* @param uploadItem If we're called during the upload process, then this will
* be set to the {@link UploadItem} that was uploaded. This way, we can directly
* use the on-disk file instead of needing to download the original from remote.
* @param image The file's contents.
*
* @param electron The {@link MLWorkerElectron} instance that allows us to call
* our Node.js layer for various functionality.

View File

@@ -3,12 +3,7 @@
import { ensure } from "@/utils/ensure";
import { Matrix, inverse } from "ml-matrix";
/**
* Clamp {@link value} to between {@link min} and {@link max}, inclusive.
*/
export const clamp = (value: number, min: number, max: number) =>
Math.min(max, Math.max(min, value));
import { clamp } from "./math";
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
@@ -80,7 +75,7 @@ const pixelRGBA = (
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation.
*/
const pixelRGBBicubic = (
export const pixelRGBBicubic = (
fx: number,
fy: number,
imageData: Uint8ClampedArray,

View File

@@ -16,7 +16,7 @@ import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { proxy } from "comlink";
import type { UploadItem } from "../upload/types";
import { regenerateFaceCrops } from "./crop";
import { clearFaceDB, faceIndex, indexableAndIndexedCounts } from "./db";
import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
import { MLWorker } from "./worker";
/**
@@ -47,6 +47,7 @@ const createComlinkWorker = async () => {
appVersion: electron.appVersion,
detectFaces: electron.detectFaces,
computeFaceEmbeddings: electron.computeFaceEmbeddings,
computeCLIPImageEmbedding: electron.computeCLIPImageEmbedding,
};
const cw = new ComlinkWorker<typeof MLWorker>(
@@ -85,9 +86,9 @@ export const logoutML = async () => {
// `terminateMLWorker` is conceptually also part of this, but for the
// reasons mentioned in [Note: Caching IDB instances in separate execution
// contexts], it gets called first in the logout sequence, and then this
// `logoutML` gets called at a later point in time.
// function (`logoutML`) gets called at a later point in time.
_isMLEnabled = false;
await clearFaceDB();
await clearMLDB();
};
/**

View File

@@ -0,0 +1,43 @@
/**
* Clamp {@link value} to between {@link min} and {@link max}, inclusive.
*/
export const clamp = (value: number, min: number, max: number) =>
Math.min(max, Math.max(min, value));
/**
* Return the dot-product of two vectors.
*
* The dot product is the sum of the products of the corresponding elements of
* the two given vectors.
*
* Precondition: The two vectors must be of the same length.
*/
export const dotProduct = (v1: number[], v2: number[]) => {
if (v1.length != v2.length)
throw new Error(`Length mismatch ${v1.length} ${v2.length}`);
let d = 0;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
for (let i = 0; i < v1.length; i++) d += v1[i]! * v2[i]!;
return d;
};
/**
* Return the L2-norm ("magnitude") of the given vector.
*
* L2-norm is the sqrt of the sum of the squares of the components of the
* vector. It can also be thought of as the sqrt of the dot product of the
* vector with itself.
*/
export const norm = (v: number[]) =>
Math.sqrt(v.reduce((a, x) => a + x * x, 0));
/**
* Return the cosine similarity of the two given vectors.
*
* The result is a value in [-1, 1] (inclusive), indicating how similar (1),
* orthogonal (0) or dissimilar (-1) the two vectors are in direction.
*
* Precondition: The two vectors must be of the same length.
*/
export const cosineSimilarity = (v1: number[], v2: number[]) =>
dotProduct(v1, v2) / (norm(v1) * norm(v2));
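Since cosine similarity is the primitive that CLIP-based natural language search ultimately relies on, here is a minimal usage sketch. The `textEmbedding` and `imageEmbeddings` values are made up (real CLIP embeddings have 512 components), and the actual search code may differ:

```ts
// Minimal sketch: rank stored CLIP image embeddings against a text embedding.
// Values are illustrative; real embeddings have 512 components each.
const textEmbedding = [0.12, -0.03, 0.98];
const imageEmbeddings = new Map<number, number[]>([
    [101, [0.1, -0.02, 0.97]],
    [102, [-0.5, 0.4, 0.01]],
]);

const results = [...imageEmbeddings.entries()]
    .map(([fileID, embedding]) => ({
        fileID,
        score: cosineSimilarity(textEmbedding, embedding),
    }))
    .sort((a, b) => b.score - a.score);
// Files whose score crosses some threshold would surface as search results.
```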

View File

@@ -10,4 +10,5 @@ export interface MLWorkerElectron {
appVersion: () => Promise<string>;
detectFaces: (input: Float32Array) => Promise<Float32Array>;
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
}

View File

@@ -11,15 +11,17 @@ import { expose } from "comlink";
import downloadManager from "../download";
import { getAllLocalFiles, getLocalTrashedFiles } from "../files";
import type { UploadItem } from "../upload/types";
import { imageBitmapAndData } from "./bitmap";
import { imageBitmapAndData, type ImageBitmapAndData } from "./bitmap";
import { indexCLIP, type CLIPIndex } from "./clip";
import {
indexableFileIDs,
markIndexingFailed,
saveCLIPIndex,
saveFaceIndex,
updateAssumingLocalFiles,
} from "./db";
import { pullFaceEmbeddings, putFaceIndex } from "./embedding";
import { type FaceIndex, indexFaces } from "./face";
import { pullFaceEmbeddings, putCLIPIndex, putFaceIndex } from "./embedding";
import { indexFaces, type FaceIndex } from "./face";
import type { MLWorkerElectron } from "./worker-electron";
const idleDurationStart = 5; /* 5 seconds */
@@ -239,7 +241,24 @@ expose(MLWorker);
* Return true if at least one embedding was pulled.
*/
const pull = async () => {
return pullFaceEmbeddings();
const res = await Promise.allSettled([
pullFaceEmbeddings(),
// TODO-ML: clip-test
// pullCLIPEmbeddings(),
]);
for (const r of res) {
switch (r.status) {
case "fulfilled":
// Return true if any pulled something.
if (r.value) return true;
break;
case "rejected":
// Throw if any failed.
throw r.reason;
}
}
// Return false if neither pulled anything.
return false;
};
/**
@@ -339,6 +358,32 @@ const index = async (
const startTime = Date.now();
const image = await imageBitmapAndData(enteFile, uploadItem, electron);
const res = await Promise.allSettled([
_indexFace(f, enteFile, image, electron, userAgent),
// TODO-ML: clip-test
// _indexCLIP(f, enteFile, image, electron, userAgent),
]);
image.bitmap.close();
const msg: string[] = [];
for (const r of res) {
if (r.status == "rejected") throw r.reason;
else msg.push(r.value);
}
log.debug(() => {
const ms = Date.now() - startTime;
return `Indexed ${msg.join(" and ")} in ${f} (${ms} ms)`;
});
};
const _indexFace = async (
f: string,
enteFile: EnteFile,
image: ImageBitmapAndData,
electron: MLWorkerElectron,
userAgent: string,
) => {
let faceIndex: FaceIndex;
try {
faceIndex = await indexFaces(enteFile, image, electron, userAgent);
@@ -346,8 +391,6 @@ const index = async (
log.error(`Failed to index faces in ${f}`, e);
await markIndexingFailed(enteFile.id);
throw e;
} finally {
image.bitmap.close();
}
// [Note: Transient and permanent indexing failures]
@@ -389,11 +432,37 @@ const index = async (
throw e;
}
log.debug(() => {
const nf = faceIndex.faceEmbedding.faces.length;
const ms = Date.now() - startTime;
return `Indexed ${nf} faces in ${f} (${ms} ms)`;
});
return faceIndex;
// A message for debug printing.
return `${faceIndex.faceEmbedding.faces.length} faces`;
};
// TODO-ML: clip-test export
export const _indexCLIP = async (
f: string,
enteFile: EnteFile,
image: ImageBitmapAndData,
electron: MLWorkerElectron,
userAgent: string,
) => {
let clipIndex: CLIPIndex;
try {
clipIndex = await indexCLIP(enteFile, image, electron, userAgent);
} catch (e) {
log.error(`Failed to index CLIP in ${f}`, e);
await markIndexingFailed(enteFile.id);
throw e;
}
// See: [Note: Transient and permanent indexing failures]
try {
await putCLIPIndex(enteFile, clipIndex);
await saveCLIPIndex(clipIndex);
} catch (e) {
log.error(`Failed to put/save CLIP index for ${f}`, e);
if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
throw e;
}
// A message for debug printing.
return "clip";
};

View File

@@ -38,8 +38,8 @@ export const apiURL = async (path: string) => (await apiOrigin()) + path;
export const customAPIOrigin = async () => {
let origin = await getKV("apiOrigin");
if (!origin && !inWorker()) {
// TODO: Migration of apiOrigin from local storage to indexed DB
// Remove me after a bit (27 June 2024).
// TODO: Migration of apiOrigin from local storage to indexed DB. Added
// 27 June 2024, 1.7.2-rc. Remove me after a bit (tag: Migration).
const legacyOrigin = localStorage.getItem("apiOrigin");
if (legacyOrigin !== null) {
origin = legacyOrigin;

View File

@@ -337,15 +337,15 @@ export interface Electron {
/**
* Return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
* See: [Note: Natural language search using CLIP]
*
* @param jpegImageData The raw bytes of the image encoded as an JPEG.
* The input is an opaque float32 array representing the image. The layout
* and exact encoding of the input is specific to our implementation and the
* ML model (CLIP) we use.
*
* @returns A CLIP embedding.
* @returns A CLIP embedding (an array of 512 floating point values).
*/
computeCLIPImageEmbedding: (
jpegImageData: Uint8Array,
) => Promise<Float32Array>;
computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a CLIP embedding of the given text if we already have the model
@@ -361,7 +361,7 @@ export interface Electron {
* embeddings are used as part of deducing user initiated search results,
* and we don't want to block that interaction on a large network request.
*
* See: [Note: CLIP based magic search]
* See: [Note: Natural language search using CLIP]
*
* @param text The string whose embedding we want to compute.
*

View File

@@ -1,7 +1,6 @@
import { ensureElectron } from "@/next/electron";
import log, { logToDisk } from "@/next/log";
import { expose, wrap, type Remote } from "comlink";
import { ensureLocalUser } from "../local-user";
/**
* A minimal wrapper for a web {@link Worker}, proxying a class of type T.
@@ -62,13 +61,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
*
* Not all workers need access to all these functions, and this can indeed be
* done in a more fine-grained, per-worker, manner if needed. For now, since it
* is a motley bunch, we just inject them all to all workers.
* is just a couple, we inject them all into all workers.
*/
const workerBridge = {
// Needed by all workers (likely, not necessarily).
// Needed by all workers (likely, but not necessarily).
logToDisk,
// Needed by MLWorker.
getAuthToken: () => ensureLocalUser().token,
// Needed by ML worker.
convertToJPEG: (imageData: Uint8Array) =>
ensureElectron().convertToJPEG(imageData),
};

View File

@@ -70,8 +70,6 @@ export const CustomError = {
TWO_FACTOR_ENABLED: "two factor enabled",
CLIENT_ERROR: "client error",
ServerError: "server error",
FILE_NOT_FOUND: "file not found",
UNSUPPORTED_PLATFORM: "Unsupported platform",
UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
URL_ALREADY_SET: "url already set",
FILE_CONVERSION_FAILED: "file conversion failed",

View File

@@ -1,11 +1,11 @@
import { EventEmitter } from "eventemitter3";
// TODO: Remove me
// When registering event handlers,
// handle errors to avoid unhandled rejection or propagation to emit call
export enum Events {
LOGOUT = "logout",
FILE_UPLOADED = "fileUploaded",
LOCAL_FILES_UPDATED = "localFilesUpdated",
}

View File

@@ -2477,11 +2477,6 @@ eventemitter3@^4.0.7:
resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-4.0.7.tgz#2de9b68f6528d5644ef5c59526a1b4a07306169f"
integrity sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==
eventemitter3@^5.0.1:
version "5.0.1"
resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-5.0.1.tgz#53f5ffd0a492ac800721bb42c66b841de96423c4"
integrity sha512-GWkBvjiSZK87ELrYOSESUYeVIc9mvLLf/nXalMOS5dYrgZq9o5OVkbZAVM06CVxYsCwH9BDZFPlQTlPA1j4ahA==
exifr@^7.1.3:
version "7.1.3"
resolved "https://registry.yarnpkg.com/exifr/-/exifr-7.1.3.tgz#f6218012c36dbb7d843222011b27f065fddbab6f"
@@ -3759,19 +3754,6 @@ p-locate@^5.0.0:
dependencies:
p-limit "^3.0.2"
p-queue@^7.1.0:
version "7.4.1"
resolved "https://registry.yarnpkg.com/p-queue/-/p-queue-7.4.1.tgz#7f86f853048beca8272abdbb7cec1ed2afc0f265"
integrity sha512-vRpMXmIkYF2/1hLBKisKeVYJZ8S2tZ0zEAmIJgdVKP2nq0nh4qCdf8bgw+ZgKrkh71AOCaqzwbJJk1WtdcF3VA==
dependencies:
eventemitter3 "^5.0.1"
p-timeout "^5.0.2"
p-timeout@^5.0.2:
version "5.1.0"
resolved "https://registry.yarnpkg.com/p-timeout/-/p-timeout-5.1.0.tgz#b3c691cf4415138ce2d9cfe071dba11f0fee085b"
integrity sha512-auFDyzzzGZZZdHz3BtET9VEz0SE/uMEAx7uWfGPucfzEwwe/xH0iVeZibQmANYE/hp9T2+UUZT5m+BKyrDp3Ew==
pako@~1.0.2:
version "1.0.11"
resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.11.tgz#6c9599d340d54dfd3946380252a35705a6b992bf"