[desktop] Consolidate face and CLIP pipelines - Part 2/x (#2408)
* Continuation of https://github.com/ente-io/ente/pull/2391
* CLIP indexing (and search) is disabled for now until the new format is finalized
@@ -112,21 +112,14 @@ resources (`build`) folder. This is used for thumbnail generation on Linux.
 
 On macOS, we use the `sips` CLI tool for conversion, but that is already
 available on the host machine, and is not bundled with our app.
 
-### AI/ML
+### ML
 
-[onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used as the
-AI/ML runtime. It powers both natural language searches (using CLIP) and face
+[onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used as the ML
+runtime. It powers both natural language searches (using CLIP) and face
 detection (using YOLO).
 
-[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
-data into raw RGB bytes before passing it to ONNX.
-
-html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.
-
-### Watch Folders
-
-[chokidar](https://github.com/paulmillr/chokidar) is used as a file system
-watcher for the watch folders functionality.
+[clip-bpe-js](https://github.com/simonwarchol/clip-bpe-js) is used for tokenizing
+the user's search phrase before computing its CLIP (text) embedding.
 
 ### ZIP
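As background for the runtime mentioned above: onnxruntime-node's API boils down to creating an `InferenceSession` and running it over tensors. A minimal sketch, with an illustrative model path and the 1x3x224x224 input shape that the preprocessing later in this diff produces (not the app's actual session setup, which lives behind `makeCachedInferenceSession`):

```ts
import * as ort from "onnxruntime-node";

// Run one inference over a preprocessed 224x224 RGB image. The real app
// downloads and caches the model; "/tmp/clip-image.onnx" is a placeholder.
const embedImage = async (input: Float32Array): Promise<Float32Array> => {
    const session = await ort.InferenceSession.create("/tmp/clip-image.onnx");
    const feeds = { input: new ort.Tensor("float32", input, [1, 3, 224, 224]) };
    const results = await session.run(feeds);
    return results.output.data as Float32Array; // the 512-dim CLIP embedding
};
```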
@@ -135,3 +128,8 @@ reading of large ZIP files (e.g. during imports of Google Takeout ZIPs).
 
 [lru-cache](https://github.com/isaacs/node-lru-cache) is used to cache file ZIP
 handles to avoid reopening them for every operation.
+
+### Watch folders
+
+[chokidar](https://github.com/paulmillr/chokidar) is used as a file system
+watcher for the watch folders functionality.
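The lru-cache usage described above is the standard bounded-cache-with-teardown pattern; a sketch of the shape (ours, not the app's actual code):

```ts
import { LRUCache } from "lru-cache";

// Keep at most a handful of open ZIP handles around; when one is evicted
// (least recently used first), close it so we don't leak file descriptors.
const zipHandles = new LRUCache<string, { close: () => Promise<void> }>({
    max: 4,
    dispose: (zip) => void zip.close(),
});
```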
@@ -33,8 +33,6 @@
         "electron-store": "^8.2",
         "electron-updater": "^6.2",
         "ffmpeg-static": "^5.2",
-        "html-entities": "^2.5",
-        "jpeg-js": "^0.4",
         "lru-cache": "^10.2",
         "next-electron-server": "^1",
         "node-stream-zip": "^1.15",
@@ -186,10 +186,8 @@ export const attachIPCHandlers = () => {
 
     // - ML
 
-    ipcMain.handle(
-        "computeCLIPImageEmbedding",
-        (_, jpegImageData: Uint8Array) =>
-            computeCLIPImageEmbedding(jpegImageData),
+    ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) =>
+        computeCLIPImageEmbedding(input),
     );
 
     ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>
@@ -1,23 +1,12 @@
-// TODO: These arise from the array indexing in the pre-processing code. Isolate
-// once that code settles down to its final place (currently duplicated across
-// web and desktop).
-/* eslint-disable @typescript-eslint/no-non-null-assertion */
-
 /**
  * @file Compute CLIP embeddings for images and text.
  *
  * The embeddings are computed using ONNX runtime, with CLIP as the model.
- *
- * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
 import Tokenizer from "clip-bpe-js";
-import jpeg from "jpeg-js";
-import fs from "node:fs/promises";
 import * as ort from "onnxruntime-node";
 import log from "../log";
-import { writeStream } from "../stream";
 import { ensure, wait } from "../utils/common";
-import { deleteTempFile, makeTempFilePath } from "../utils/temp";
 import { makeCachedInferenceSession } from "./ml";
 
 const cachedCLIPImageSession = makeCachedInferenceSession(
@@ -25,241 +14,16 @@ const cachedCLIPImageSession = makeCachedInferenceSession(
     351468764 /* 335.2 MB */,
 );
 
-export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => {
-    const tempFilePath = await makeTempFilePath();
-    const imageStream = new Response(jpegImageData.buffer).body;
-    await writeStream(tempFilePath, ensure(imageStream));
-    try {
-        return await clipImageEmbedding_(tempFilePath);
-    } finally {
-        await deleteTempFile(tempFilePath);
-    }
-};
-
-const clipImageEmbedding_ = async (jpegFilePath: string) => {
-    const session = await cachedCLIPImageSession();
-    const t1 = Date.now();
-    const rgbData = await getRGBData(jpegFilePath);
-    const feeds = {
-        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
-    };
-    const t2 = Date.now();
-    const results = await session.run(feeds);
-    log.debug(
-        () =>
-            `ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
-    );
-    /* Need these model specific casts to type the result */
-    const imageEmbedding = ensure(results.output).data as Float32Array;
-    return normalizeEmbedding(imageEmbedding);
-};
+export const computeCLIPImageEmbedding = async (input: Float32Array) => {
+    const session = await cachedCLIPImageSession();
+    const t = Date.now();
+    const feeds = {
+        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
+    };
+    const results = await session.run(feeds);
+    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
+    return ensure(results.output).data as Float32Array;
+};
 
-const getRGBData = async (jpegFilePath: string): Promise<Float32Array> => {
-    const jpegData = await fs.readFile(jpegFilePath);
-    const rawImageData = jpeg.decode(jpegData, {
-        useTArray: true,
-        formatAsRGBA: true,
-    }); // TODO: manav: make sure this works on all images, not just jpeg
-    const pixelData = rawImageData.data;
-
-    const requiredWidth = 224;
-    const requiredHeight = 224;
-    const requiredSize = 3 * requiredWidth * requiredHeight;
-    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
-    const std: number[] = [0.26862954, 0.26130258, 0.27577711];
-
-    const scale = Math.max(
-        requiredWidth / rawImageData.width,
-        requiredHeight / rawImageData.height,
-    );
-    const scaledWidth = Math.round(rawImageData.width * scale);
-    const scaledHeight = Math.round(rawImageData.height * scale);
-    const widthOffset = Math.max(0, scaledWidth - requiredWidth) / 2;
-    const heightOffset = Math.max(0, scaledHeight - requiredHeight) / 2;
-
-    const processedImage = new Float32Array(requiredSize);
-
-    // Populate the Float32Array with normalized pixel values.
-    let pi = 0;
-    const cOffsetG = requiredHeight * requiredWidth; // ChannelOffsetGreen
-    const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
-    for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
-        for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
-            const { r, g, b } = pixelRGBBicubic(
-                w / scale,
-                h / scale,
-                pixelData,
-                rawImageData.width,
-                rawImageData.height,
-            );
-            processedImage[pi] = (r / 255.0 - mean[0]!) / std[0]!;
-            processedImage[pi + cOffsetG] = (g / 255.0 - mean[1]!) / std[1]!;
-            processedImage[pi + cOffsetB] = (b / 255.0 - mean[2]!) / std[2]!;
-            pi++;
-        }
-    }
-    return processedImage;
-};
-
-// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
-const pixelRGBBicubic = (
-    fx: number,
-    fy: number,
-    imageData: Uint8Array,
-    imageWidth: number,
-    imageHeight: number,
-) => {
-    // Clamp to image boundaries.
-    fx = clamp(fx, 0, imageWidth - 1);
-    fy = clamp(fy, 0, imageHeight - 1);
-
-    const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
-    const px = x - 1;
-    const nx = x + 1;
-    const ax = x + 2;
-    const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1);
-    const py = y - 1;
-    const ny = y + 1;
-    const ay = y + 2;
-    const dx = fx - x;
-    const dy = fy - y;
-
-    const cubic = (
-        dx: number,
-        ipp: number,
-        icp: number,
-        inp: number,
-        iap: number,
-    ) =>
-        icp +
-        0.5 *
-            (dx * (-ipp + inp) +
-                dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
-                dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap));
-
-    const icc = pixelRGBA(imageData, imageWidth, imageHeight, x, y);
-
-    const ipp =
-        px < 0 || py < 0
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, px, py);
-    const icp =
-        px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, x, py);
-    const inp =
-        py < 0 || nx >= imageWidth
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, nx, py);
-    const iap =
-        ax >= imageWidth || py < 0
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, ax, py);
-
-    const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r);
-    const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g);
-    const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b);
-    // const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
-
-    const ipc =
-        px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, px, y);
-    const inc =
-        nx >= imageWidth
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, nx, y);
-    const iac =
-        ax >= imageWidth
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, ax, y);
-
-    const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r);
-    const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g);
-    const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b);
-    // const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
-
-    const ipn =
-        px < 0 || ny >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, px, ny);
-    const icn =
-        ny >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, x, ny);
-    const inn =
-        nx >= imageWidth || ny >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, nx, ny);
-    const ian =
-        ax >= imageWidth || ny >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, ax, ny);
-
-    const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r);
-    const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g);
-    const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b);
-    // const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
-
-    const ipa =
-        px < 0 || ay >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, px, ay);
-    const ica =
-        ay >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, x, ay);
-    const ina =
-        nx >= imageWidth || ay >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, nx, ay);
-    const iaa =
-        ax >= imageWidth || ay >= imageHeight
-            ? icc
-            : pixelRGBA(imageData, imageWidth, imageHeight, ax, ay);
-
-    const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r);
-    const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g);
-    const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b);
-    // const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
-
-    const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255));
-    const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255));
-    const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255));
-    // const c3 = cubic(dy, ip3, ic3, in3, ia3);
-
-    return { r: c0, g: c1, b: c2 };
-};
-
-// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
-const clamp = (value: number, min: number, max: number) =>
-    Math.min(max, Math.max(min, value));
-
-// NOTE: exact duplicate of the function in web/apps/photos/src/services/face/image.ts
-const pixelRGBA = (
-    imageData: Uint8Array,
-    width: number,
-    height: number,
-    x: number,
-    y: number,
-) => {
-    if (x < 0 || x >= width || y < 0 || y >= height) {
-        return { r: 0, g: 0, b: 0, a: 0 };
-    }
-    const index = (y * width + x) * 4;
-    return {
-        r: ensure(imageData[index]),
-        g: ensure(imageData[index + 1]),
-        b: ensure(imageData[index + 2]),
-        a: ensure(imageData[index + 3]),
-    };
-};
-
-const normalizeEmbedding = (embedding: Float32Array) => {
-    let normalization = 0;
-    for (const v of embedding) normalization += v * v;
-
-    const sqrtNormalization = Math.sqrt(normalization);
-    for (let index = 0; index < embedding.length; index++)
-        embedding[index] = ensure(embedding[index]) / sqrtNormalization;
-
-    return embedding;
-};
 
 const cachedCLIPTextSession = makeCachedInferenceSession(
@@ -302,6 +66,5 @@ export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
         () =>
             `ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const textEmbedding = ensure(results.output).data as Float32Array;
-    return normalizeEmbedding(textEmbedding);
+    return ensure(results.output).data as Float32Array;
 };
@@ -163,8 +163,8 @@ const ffmpegExec = (
 
 // - ML
 
-const computeCLIPImageEmbedding = (jpegImageData: Uint8Array) =>
-    ipcRenderer.invoke("computeCLIPImageEmbedding", jpegImageData);
+const computeCLIPImageEmbedding = (input: Float32Array) =>
+    ipcRenderer.invoke("computeCLIPImageEmbedding", input);
 
 const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
     ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);
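The preload change above is the renderer-side half of the IPC rename: `ipcRenderer.invoke` on a channel pairs with the `ipcMain.handle` registration shown earlier, and the parameter types on the two sides must be kept in sync by hand, which is why both files change together in this diff. Schematically (a sketch; the real preload exposes many more methods):

```ts
import { contextBridge, ipcRenderer } from "electron";

contextBridge.exposeInMainWorld("electron", {
    // Must mirror ipcMain.handle("computeCLIPImageEmbedding", ...) in main.
    computeCLIPImageEmbedding: (input: Float32Array) =>
        ipcRenderer.invoke("computeCLIPImageEmbedding", input),
});
```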
@@ -1900,11 +1900,6 @@ hosted-git-info@^4.1.0:
   dependencies:
     lru-cache "^6.0.0"
 
-html-entities@^2.5:
-  version "2.5.2"
-  resolved "https://registry.yarnpkg.com/html-entities/-/html-entities-2.5.2.tgz#201a3cf95d3a15be7099521620d19dfb4f65359f"
-  integrity sha512-K//PSRMQk4FZ78Kyau+mZurHn3FH0Vwr+H36eE0rPbeYkRRi9YxceYPhuN60UwWorxyKHhqoAJl2OFKa4BVtaA==
-
 http-cache-semantics@^4.0.0, http-cache-semantics@^4.1.0:
   version "4.1.1"
   resolved "https://registry.yarnpkg.com/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz#abe02fcb2985460bf0323be664436ec3476a6d5a"

@@ -2130,11 +2125,6 @@ jake@^10.8.5:
     filelist "^1.0.4"
     minimatch "^3.1.2"
 
-jpeg-js@^0.4:
-  version "0.4.4"
-  resolved "https://registry.yarnpkg.com/jpeg-js/-/jpeg-js-0.4.4.tgz#a9f1c6f1f9f0fa80cdb3484ed9635054d28936aa"
-  integrity sha512-WZzeDOEtTOBK4Mdsar0IqEU5sMr3vSV2RqkAIzUEV2BHnUfKGyswWFPFwK5EeDo93K3FohSHbLAjj0s1Wzd+dg==
-
 js-yaml@^4.1.0:
   version "4.1.0"
   resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-4.1.0.tgz#c1fb65f8f5017901cdd2c951864ba18458a10602"
@@ -28,7 +28,6 @@
         "memoize-one": "^6.0.0",
         "ml-matrix": "^6.11",
         "p-debounce": "^4.0.0",
-        "p-queue": "^7.1.0",
         "photoswipe": "file:./thirdparty/photoswipe",
         "piexifjs": "^1.0.6",
         "pure-react-carousel": "^1.30.1",
@@ -1,8 +1,7 @@
 import { VerticallyCenteredFlex } from "@ente/shared/components/Container";
 import { EnteMenuItem } from "@ente/shared/components/Menu/EnteMenuItem";
 import ChevronRight from "@mui/icons-material/ChevronRight";
 import ScienceIcon from "@mui/icons-material/Science";
-import { Box, DialogProps, Stack, Typography } from "@mui/material";
+import { Box, DialogProps, Stack } from "@mui/material";
 import { EnteDrawer } from "components/EnteDrawer";
 import { MenuItemGroup } from "components/Menu/MenuItemGroup";
 import MenuSectionTitle from "components/Menu/MenuSectionTitle";

@@ -11,9 +10,7 @@ import { MLSearchSettings } from "components/ml/MLSearchSettings";
 import { t } from "i18next";
 import isElectron from "is-electron";
 import { AppContext } from "pages/_app";
-import { useContext, useEffect, useState } from "react";
-import { CLIPIndexingStatus, clipService } from "services/clip-service";
-import { formatNumber } from "utils/number/format";
+import { useContext, useState } from "react";
 
 export default function AdvancedSettings({ open, onClose, onRootClose }) {
     const appContext = useContext(AppContext);

@@ -39,16 +36,17 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
         appContext.setIsCFProxyDisabled(!appContext.isCFProxyDisabled);
     };
 
-    const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
-        indexed: 0,
-        pending: 0,
-    });
+    // TODO-ML:
+    // const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
+    //     indexed: 0,
+    //     pending: 0,
+    // });
 
-    useEffect(() => {
-        clipService.setOnUpdateHandler(setIndexingStatus);
-        clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
-        return () => clipService.setOnUpdateHandler(undefined);
-    }, []);
+    // useEffect(() => {
+    //     clipService.setOnUpdateHandler(setIndexingStatus);
+    //     clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
+    //     return () => clipService.setOnUpdateHandler(undefined);
+    // }, []);
 
     return (
         <EnteDrawer

@@ -97,7 +95,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
                     />
                 </Box>
 
-                {isElectron() && (
+                {/* TODO-ML: isElectron() && (
                     <Box>
                         <MenuSectionTitle
                             title={t("MAGIC_SEARCH_STATUS")}

@@ -131,7 +129,7 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
                         </VerticallyCenteredFlex>
                     </Stack>
                 </Box>
-                )}
+                )*/}
             </Stack>
         </Box>
     </Stack>
@@ -81,7 +81,6 @@ import {
     useState,
 } from "react";
 import { useDropzone } from "react-dropzone";
-import { clipService } from "services/clip-service";
 import {
     constructEmailList,
     constructUserIDToEmailMap,

@@ -390,7 +389,6 @@ export default function Gallery() {
             syncWithRemote(false, true);
         }, SYNC_INTERVAL_IN_MICROSECONDS);
         if (electron) {
-            // void clipService.setupOnFileUploadListener();
             electron.onMainWindowFocus(() => syncWithRemote(false, true));
             if (await shouldShowWhatsNew()) setOpenWhatsNew(true);
         }

@@ -398,10 +396,7 @@
         main();
         return () => {
             clearInterval(syncInterval.current);
-            if (electron) {
-                electron.onMainWindowFocus(undefined);
-                clipService.removeOnFileUploadListener();
-            }
+            if (electron) electron.onMainWindowFocus(undefined);
         };
     }, []);
 
@@ -1,402 +0,0 @@
-import { FILE_TYPE } from "@/media/file-type";
-import downloadManager from "@/new/photos/services/download";
-import { getAllLocalFiles, getLocalFiles } from "@/new/photos/services/files";
-import { EnteFile } from "@/new/photos/types/file";
-import { ensureElectron } from "@/next/electron";
-import log from "@/next/log";
-import ComlinkCryptoWorker from "@ente/shared/crypto";
-import { CustomError } from "@ente/shared/error";
-import { Events, eventBus } from "@ente/shared/events";
-import { LS_KEYS, getData } from "@ente/shared/storage/localStorage";
-import PQueue from "p-queue";
-import { Embedding } from "types/embedding";
-import { getPersonalFiles } from "utils/file";
-import { localCLIPEmbeddings, putEmbedding } from "./embeddingService";
-
-/** Status of CLIP indexing on the images in the user's local library. */
-export interface CLIPIndexingStatus {
-    /** Number of items pending indexing. */
-    pending: number;
-    /** Number of items that have already been indexed. */
-    indexed: number;
-}
-
-/**
- * Use a CLIP based neural network for natural language search.
- *
- * [Note: CLIP based magic search]
- *
- * CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
- * (image, text) pairs. It can be thought of as two separate (but jointly
- * trained) encoders - one for images, and one for text - that both map to the
- * same embedding space.
- *
- * We use this for natural language search within the app (aka "magic search"):
- *
- * 1. Pre-compute an embedding for each image.
- *
- * 2. When the user searches, compute an embedding for the search term.
- *
- * 3. Use cosine similarity to find the image (embedding) closest to
- *    the text (embedding).
- *
- * More details are in our [blog
- * post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
- * initial launch of this feature using the GGML runtime.
- *
- * Since the initial launch, we've switched over to another runtime,
- * [ONNX](https://onnxruntime.ai).
- *
- * Note that we don't train the neural network - we only use one of the publicly
- * available pre-trained neural networks for inference. These neural networks
- * are wholly defined by their connectivity and weights. ONNX, our ML runtime,
- * loads these weights and instantiates a running network that we can use to
- * compute the embeddings.
- *
- * Theoretically, the same CLIP model can be loaded by different frameworks /
- * runtimes, but in practice each runtime has its own preferred format, and
- * there are also quantization tradeoffs. So there is a specific model (a binary
- * encoding of weights) tied to our current runtime that we use.
- *
- * To ensure that the embeddings, for the most part, can be shared, whenever
- * possible we try to ensure that all the preprocessing steps, and the model
- * itself, is the same across clients - web and mobile.
- */
-class CLIPService {
-    private embeddingExtractionInProgress: AbortController | null = null;
-    private reRunNeeded = false;
-    private indexingStatus: CLIPIndexingStatus = {
-        pending: 0,
-        indexed: 0,
-    };
-    private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
-    private liveEmbeddingExtractionQueue: PQueue;
-    private onFileUploadedHandler:
-        | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
-        | null = null;
-
-    constructor() {
-        this.liveEmbeddingExtractionQueue = new PQueue({
-            concurrency: 1,
-        });
-    }
-
-    async logout() {
-        if (this.embeddingExtractionInProgress) {
-            this.embeddingExtractionInProgress.abort();
-        }
-        if (this.onFileUploadedHandler) {
-            await this.removeOnFileUploadListener();
-        }
-    }
-
-    setupOnFileUploadListener = async () => {
-        try {
-            if (this.onFileUploadedHandler) {
-                log.info("file upload listener already setup");
-                return;
-            }
-            log.info("setting up file upload listener");
-            this.onFileUploadedHandler = (args) => {
-                this.runLocalFileClipExtraction(args);
-            };
-            eventBus.on(Events.FILE_UPLOADED, this.onFileUploadedHandler, this);
-            log.info("setup file upload listener successfully");
-        } catch (e) {
-            log.error("failed to setup clip service", e);
-        }
-    };
-
-    removeOnFileUploadListener = async () => {
-        try {
-            if (!this.onFileUploadedHandler) {
-                log.info("file upload listener already removed");
-                return;
-            }
-            log.info("removing file upload listener");
-            eventBus.removeListener(
-                Events.FILE_UPLOADED,
-                this.onFileUploadedHandler,
-                this,
-            );
-            this.onFileUploadedHandler = null;
-            log.info("removed file upload listener successfully");
-        } catch (e) {
-            log.error("failed to remove clip service", e);
-        }
-    };
-
-    getIndexingStatus = async () => {
-        if (
-            this.indexingStatus.pending === 0 &&
-            this.indexingStatus.indexed === 0
-        ) {
-            this.indexingStatus = await initialIndexingStatus();
-        }
-        return this.indexingStatus;
-    };
-
-    /**
-     * Set the {@link handler} to invoke whenever our indexing status changes.
-     */
-    setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
-        this.onUpdateHandler = handler;
-    };
-
-    scheduleImageEmbeddingExtraction = async () => {
-        try {
-            if (this.embeddingExtractionInProgress) {
-                log.info(
-                    "clip embedding extraction already in progress, scheduling re-run",
-                );
-                this.reRunNeeded = true;
-                return;
-            } else {
-                log.info(
-                    "clip embedding extraction not in progress, starting clip embedding extraction",
-                );
-            }
-            const canceller = new AbortController();
-            this.embeddingExtractionInProgress = canceller;
-            try {
-                await this.runClipEmbeddingExtraction(canceller);
-            } finally {
-                this.embeddingExtractionInProgress = null;
-                if (!canceller.signal.aborted && this.reRunNeeded) {
-                    this.reRunNeeded = false;
-                    log.info("re-running clip embedding extraction");
-                    setTimeout(
-                        () => this.scheduleImageEmbeddingExtraction(),
-                        0,
-                    );
-                }
-            }
-        } catch (e) {
-            if (e.message !== CustomError.REQUEST_CANCELLED) {
-                log.error("failed to schedule clip embedding extraction", e);
-            }
-        }
-    };
-
-    getTextEmbeddingIfAvailable = async (text: string) => {
-        return ensureElectron().computeCLIPTextEmbeddingIfAvailable(text);
-    };
-
-    private runClipEmbeddingExtraction = async (canceller: AbortController) => {
-        try {
-            const user = getData(LS_KEYS.USER);
-            if (!user) {
-                return;
-            }
-            const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
-            const existingEmbeddings = await localCLIPEmbeddings();
-            const pendingFiles = await getNonClipEmbeddingExtractedFiles(
-                localFiles,
-                existingEmbeddings,
-            );
-            this.updateIndexingStatus({
-                indexed: existingEmbeddings.length,
-                pending: pendingFiles.length,
-            });
-            if (pendingFiles.length === 0) {
-                log.info("no clip embedding extraction needed, all done");
-                return;
-            }
-            log.info(
-                `starting clip embedding extraction for ${pendingFiles.length} files`,
-            );
-            for (const file of pendingFiles) {
-                try {
-                    log.info(
-                        `extracting clip embedding for file: ${file.metadata.title} fileID: ${file.id}`,
-                    );
-                    if (canceller.signal.aborted) {
-                        throw Error(CustomError.REQUEST_CANCELLED);
-                    }
-                    const embeddingData =
-                        await this.extractFileClipImageEmbedding(file);
-                    log.info(
-                        `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
-                    );
-                    await this.encryptAndUploadEmbedding(file, embeddingData);
-                    this.onSuccessStatusUpdater();
-                    log.info(
-                        `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
-                    );
-                } catch (e) {
-                    if (e?.message !== CustomError.REQUEST_CANCELLED) {
-                        log.error(
-                            "failed to extract clip embedding for file",
-                            e,
-                        );
-                    }
-                    if (
-                        e?.message === CustomError.REQUEST_CANCELLED ||
-                        e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)
-                    ) {
-                        throw e;
-                    }
-                }
-            }
-        } catch (e) {
-            if (e.message !== CustomError.REQUEST_CANCELLED) {
-                log.error("failed to extract clip embedding", e);
-            }
-            throw e;
-        }
-    };
-
-    private async runLocalFileClipExtraction(arg: {
-        enteFile: EnteFile;
-        localFile: globalThis.File;
-    }) {
-        const { enteFile, localFile } = arg;
-        log.info(
-            `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
-            enteFile.id,
-        );
-        if (enteFile.metadata.fileType === FILE_TYPE.VIDEO) {
-            log.info(
-                `skipping video file for clip embedding extraction file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
-            );
-            return;
-        }
-        const extension = enteFile.metadata.title.split(".").pop();
-        if (!extension || !["jpg", "jpeg"].includes(extension)) {
-            log.info(
-                `skipping non jpg file for clip embedding extraction file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
-            );
-            return;
-        }
-        log.info(
-            `queuing up for local clip embedding extraction for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
-        );
-        try {
-            await this.liveEmbeddingExtractionQueue.add(async () => {
-                const embedding =
-                    await this.extractLocalFileClipImageEmbedding(localFile);
-                await this.encryptAndUploadEmbedding(enteFile, embedding);
-            });
-            log.info(
-                `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
-            );
-        } catch (e) {
-            log.error("Failed in ML onFileUploadedHandler", e);
-        }
-    }
-
-    private extractLocalFileClipImageEmbedding = async (localFile: File) => {
-        const file = await localFile
-            .arrayBuffer()
-            .then((buffer) => new Uint8Array(buffer));
-        return await ensureElectron().computeCLIPImageEmbedding(file);
-    };
-
-    private encryptAndUploadEmbedding = async (
-        file: EnteFile,
-        embeddingData: Float32Array,
-    ) => {
-        if (embeddingData?.length !== 512) {
-            throw Error(
-                `invalid length embedding data length: ${embeddingData?.length}`,
-            );
-        }
-        const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
-        const { file: encryptedEmbeddingData } =
-            await comlinkCryptoWorker.encryptEmbedding(embeddingData, file.key);
-        log.info(
-            `putting clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
-        );
-        await putEmbedding({
-            fileID: file.id,
-            encryptedEmbedding: encryptedEmbeddingData.encryptedData,
-            decryptionHeader: encryptedEmbeddingData.decryptionHeader,
-            model: "onnx-clip",
-        });
-    };
-
-    private updateIndexingStatus = (status: CLIPIndexingStatus) => {
-        this.indexingStatus = status;
-        const handler = this.onUpdateHandler;
-        if (handler) handler(status);
-    };
-
-    private extractFileClipImageEmbedding = async (file: EnteFile) => {
-        const thumb = await downloadManager.getThumbnail(file);
-        const embedding =
-            await ensureElectron().computeCLIPImageEmbedding(thumb);
-        return embedding;
-    };
-
-    private onSuccessStatusUpdater = () => {
-        this.updateIndexingStatus({
-            pending: this.indexingStatus.pending - 1,
-            indexed: this.indexingStatus.indexed + 1,
-        });
-    };
-}
-
-export const clipService = new CLIPService();
-
-const getNonClipEmbeddingExtractedFiles = async (
-    files: EnteFile[],
-    existingEmbeddings: Embedding[],
-) => {
-    const existingEmbeddingFileIds = new Set<number>();
-    existingEmbeddings.forEach((embedding) =>
-        existingEmbeddingFileIds.add(embedding.fileID),
-    );
-    const idSet = new Set<number>();
-    return files.filter((file) => {
-        if (idSet.has(file.id)) {
-            return false;
-        }
-        if (existingEmbeddingFileIds.has(file.id)) {
-            return false;
-        }
-        idSet.add(file.id);
-        return true;
-    });
-};
-
-export const computeClipMatchScore = async (
-    imageEmbedding: Float32Array,
-    textEmbedding: Float32Array,
-) => {
-    if (imageEmbedding.length !== textEmbedding.length) {
-        throw Error("imageEmbedding and textEmbedding length mismatch");
-    }
-    let score = 0;
-    let imageNormalization = 0;
-    let textNormalization = 0;
-
-    for (let index = 0; index < imageEmbedding.length; index++) {
-        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
-        textNormalization += textEmbedding[index] * textEmbedding[index];
-    }
-    for (let index = 0; index < imageEmbedding.length; index++) {
-        imageEmbedding[index] =
-            imageEmbedding[index] / Math.sqrt(imageNormalization);
-        textEmbedding[index] =
-            textEmbedding[index] / Math.sqrt(textNormalization);
-    }
-    for (let index = 0; index < imageEmbedding.length; index++) {
-        score += imageEmbedding[index] * textEmbedding[index];
-    }
-    return score;
-};
-
-const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
-    const user = getData(LS_KEYS.USER);
-    if (!user) throw new Error("Orphan CLIP indexing without a login");
-    const allEmbeddings = await localCLIPEmbeddings();
-    const localFiles = getPersonalFiles(await getLocalFiles(), user);
-    const pendingFiles = await getNonClipEmbeddingExtractedFiles(
-        localFiles,
-        allEmbeddings,
-    );
-    return {
-        indexed: allEmbeddings.length,
-        pending: pendingFiles.length,
-    };
-};
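One detail worth noting in the deleted `computeClipMatchScore` above: after both embeddings have been L2-normalized, the score is just their dot product, which for unit vectors equals the cosine similarity; scores at or above the (also deleted) `CLIP_SCORE_THRESHOLD` of 0.23 counted as matches. A standalone restatement:

```ts
// Cosine similarity of two unit-length embeddings reduces to a dot product.
const dotProduct = (a: Float32Array, b: Float32Array): number => {
    if (a.length !== b.length) throw new Error("length mismatch");
    let sum = 0;
    for (let i = 0; i < a.length; i++) sum += a[i] * b[i];
    return sum;
};
```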
@@ -1,352 +0,0 @@
-import {
-    getAllLocalFiles,
-    getLocalTrashedFiles,
-} from "@/new/photos/services/files";
-import type { EmbeddingModel } from "@/new/photos/services/ml/embedding";
-import type { FaceIndex } from "@/new/photos/services/ml/face";
-import { EnteFile } from "@/new/photos/types/file";
-import { inWorker } from "@/next/env";
-import log from "@/next/log";
-import { apiURL } from "@/next/origins";
-import { workerBridge } from "@/next/worker/worker-bridge";
-import ComlinkCryptoWorker from "@ente/shared/crypto";
-import { CustomError } from "@ente/shared/error";
-import HTTPService from "@ente/shared/network/HTTPService";
-import localForage from "@ente/shared/storage/localForage";
-import { getToken } from "@ente/shared/storage/localStorage/helpers";
-import type {
-    Embedding,
-    EncryptedEmbedding,
-    GetEmbeddingDiffResponse,
-    PutEmbeddingRequest,
-} from "types/embedding";
-import { getLocalCollections } from "./collectionService";
-
-type FileML = FaceIndex & {
-    updatedAt: number;
-};
-const DIFF_LIMIT = 500;
-
-/** Local storage key suffix for embedding sync times */
-const embeddingSyncTimeLSKeySuffix = "embedding_sync_time";
-/** Local storage key for CLIP embeddings. */
-const clipEmbeddingsLSKey = "embeddings_v2";
-const FILE_EMBEDING_TABLE = "file_embeddings";
-
-/** Return all CLIP embeddings that we have available locally. */
-export const localCLIPEmbeddings = async () =>
-    (await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip");
-
-const storedCLIPEmbeddings = async () => {
-    const embeddings: Array<Embedding> =
-        await localForage.getItem<Embedding[]>(clipEmbeddingsLSKey);
-    if (!embeddings) {
-        // Migrate
-        await localForage.removeItem("embeddings");
-        await localForage.removeItem("embedding_sync_time");
-        await localForage.setItem(clipEmbeddingsLSKey, []);
-        return [];
-    }
-    return embeddings;
-};
-
-export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
-    const embeddings: Array<FileML> =
-        await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
-    if (!embeddings) {
-        return [];
-    }
-    return embeddings;
-};
-
-const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
-    return (
-        (await localForage.getItem<number>(
-            `${model}-${embeddingSyncTimeLSKeySuffix}`,
-        )) ?? 0
-    );
-};
-
-const setModelEmbeddingSyncTime = async (
-    model: EmbeddingModel,
-    time: number,
-) => {
-    await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time);
-};
-
-/**
- * Fetch new CLIP embeddings from the server and save them locally. Also prune
- * local embeddings for any files that no longer exist locally.
- */
-export const syncCLIPEmbeddings = async () => {
-    const model: EmbeddingModel = "onnx-clip";
-    try {
-        let allEmbeddings = await storedCLIPEmbeddings();
-        const localFiles = await getAllLocalFiles();
-        const hiddenAlbums = await getLocalCollections("hidden");
-        const localTrashFiles = await getLocalTrashedFiles();
-        const fileIdToKeyMap = new Map<number, string>();
-        const allLocalFiles = [...localFiles, ...localTrashFiles];
-        allLocalFiles.forEach((file) => {
-            fileIdToKeyMap.set(file.id, file.key);
-        });
-        await cleanupDeletedEmbeddings(
-            allLocalFiles,
-            allEmbeddings,
-            clipEmbeddingsLSKey,
-        );
-        log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
-
-        let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
-        log.info(
-            `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
-        );
-        let response: GetEmbeddingDiffResponse;
-        do {
-            response = await getEmbeddingsDiff(modelLastSinceTime, model);
-            if (!response.diff?.length) {
-                return;
-            }
-            // Note: in rare cases we might get a diff entry for an embedding
-            // corresponding to a file which has been deleted (but whose
-            // embedding is enqueued for deletion). Client should expect such a
-            // scenario (all it has to do is just ignore them).
-            const newEmbeddings = await Promise.all(
-                response.diff.map(async (embedding) => {
-                    try {
-                        const {
-                            encryptedEmbedding,
-                            decryptionHeader,
-                            ...rest
-                        } = embedding;
-                        const worker = await ComlinkCryptoWorker.getInstance();
-                        const fileKey = fileIdToKeyMap.get(embedding.fileID);
-                        if (!fileKey) {
-                            throw Error(CustomError.FILE_NOT_FOUND);
-                        }
-                        const decryptedData = await worker.decryptEmbedding(
-                            encryptedEmbedding,
-                            decryptionHeader,
-                            fileIdToKeyMap.get(embedding.fileID),
-                        );
-
-                        return {
-                            ...rest,
-                            embedding: decryptedData,
-                        } as Embedding;
-                    } catch (e) {
-                        let hasHiddenAlbums = false;
-                        if (e.message === CustomError.FILE_NOT_FOUND) {
-                            hasHiddenAlbums = hiddenAlbums?.length > 0;
-                        }
-                        log.error(
-                            `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
-                            e,
-                        );
-                    }
-                }),
-            );
-            allEmbeddings = getLatestVersionEmbeddings([
-                ...allEmbeddings,
-                ...newEmbeddings,
-            ]);
-            modelLastSinceTime = response.diff.reduce(
-                (max, { updatedAt }) => Math.max(max, updatedAt),
-                modelLastSinceTime,
-            );
-            await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings);
-            await setModelEmbeddingSyncTime(model, modelLastSinceTime);
-            log.info(
-                `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
-            );
-        } while (response.diff.length > 0);
-    } catch (e) {
-        log.error("Sync embeddings failed", e);
-    }
-};
-
-export const syncFaceEmbeddings = async () => {
-    const model: EmbeddingModel = "file-ml-clip-face";
-    try {
-        let allEmbeddings: FileML[] = await getFileMLEmbeddings();
-        const localFiles = await getAllLocalFiles();
-        const hiddenAlbums = await getLocalCollections("hidden");
-        const localTrashFiles = await getLocalTrashedFiles();
-        const fileIdToKeyMap = new Map<number, string>();
-        const allLocalFiles = [...localFiles, ...localTrashFiles];
-        allLocalFiles.forEach((file) => {
-            fileIdToKeyMap.set(file.id, file.key);
-        });
-        await cleanupDeletedEmbeddings(
-            allLocalFiles,
-            allEmbeddings,
-            FILE_EMBEDING_TABLE,
-        );
-        log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
-
-        let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
-        log.info(
-            `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
-        );
-        let response: GetEmbeddingDiffResponse;
-        do {
-            response = await getEmbeddingsDiff(modelLastSinceTime, model);
-            if (!response.diff?.length) {
-                return;
-            }
-            const newEmbeddings = await Promise.all(
-                response.diff.map(async (embedding) => {
-                    try {
-                        const worker = await ComlinkCryptoWorker.getInstance();
-                        const fileKey = fileIdToKeyMap.get(embedding.fileID);
-                        if (!fileKey) {
-                            throw Error(CustomError.FILE_NOT_FOUND);
-                        }
-                        const decryptedData = await worker.decryptMetadata(
-                            embedding.encryptedEmbedding,
-                            embedding.decryptionHeader,
-                            fileIdToKeyMap.get(embedding.fileID),
-                        );
-
-                        return {
-                            ...decryptedData,
-                            updatedAt: embedding.updatedAt,
-                        } as unknown as FileML;
-                    } catch (e) {
-                        let hasHiddenAlbums = false;
-                        if (e.message === CustomError.FILE_NOT_FOUND) {
-                            hasHiddenAlbums = hiddenAlbums?.length > 0;
-                        }
-                        log.error(
-                            `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
-                            e,
-                        );
-                    }
-                }),
-            );
-            allEmbeddings = getLatestVersionFileEmbeddings([
-                ...allEmbeddings,
-                ...newEmbeddings,
-            ]);
-            modelLastSinceTime = response.diff.reduce(
-                (max, { updatedAt }) => Math.max(max, updatedAt),
-                modelLastSinceTime,
-            );
-            await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
-            await setModelEmbeddingSyncTime(model, modelLastSinceTime);
-            log.info(
-                `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
-            );
-        } while (response.diff.length > 0);
-    } catch (e) {
-        log.error("Sync embeddings failed", e);
-    }
-};
-
-const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
-    const latestVersionEntities = new Map<number, Embedding>();
-    embeddings.forEach((embedding) => {
-        if (!embedding?.fileID) {
-            return;
-        }
-        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
-        if (
-            !existingEmbeddings ||
-            existingEmbeddings.updatedAt < embedding.updatedAt
-        ) {
-            latestVersionEntities.set(embedding.fileID, embedding);
-        }
-    });
-    return Array.from(latestVersionEntities.values());
-};
-
-const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
-    const latestVersionEntities = new Map<number, FileML>();
-    embeddings.forEach((embedding) => {
-        if (!embedding?.fileID) {
-            return;
-        }
-        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
-        if (
-            !existingEmbeddings ||
-            existingEmbeddings.updatedAt < embedding.updatedAt
-        ) {
-            latestVersionEntities.set(embedding.fileID, embedding);
-        }
-    });
-    return Array.from(latestVersionEntities.values());
-};
-
-export const getEmbeddingsDiff = async (
-    sinceTime: number,
-    model: EmbeddingModel,
-): Promise<GetEmbeddingDiffResponse> => {
-    try {
-        const token = getToken();
-        if (!token) {
-            return;
-        }
-        const response = await HTTPService.get(
-            await apiURL("/embeddings/diff"),
-            {
-                sinceTime,
-                limit: DIFF_LIMIT,
-                model,
-            },
-            {
-                "X-Auth-Token": token,
-            },
-        );
-        return await response.data;
-    } catch (e) {
-        log.error("get embeddings diff failed", e);
-        throw e;
-    }
-};
-
-export const putEmbedding = async (
-    putEmbeddingReq: PutEmbeddingRequest,
-): Promise<EncryptedEmbedding> => {
-    try {
-        const token = inWorker()
-            ? await workerBridge.getAuthToken()
-            : getToken();
-        if (!token) {
-            log.info("putEmbedding failed: token not found");
-            throw Error(CustomError.TOKEN_MISSING);
-        }
-        const resp = await HTTPService.put(
-            await apiURL("/embeddings"),
-            putEmbeddingReq,
-            null,
-            {
-                "X-Auth-Token": token,
-            },
-        );
-        return resp.data;
-    } catch (e) {
-        log.error("put embedding failed", e);
-        throw e;
-    }
-};
-
-export const cleanupDeletedEmbeddings = async (
-    allLocalFiles: EnteFile[],
-    allLocalEmbeddings: Embedding[] | FileML[],
-    tableName: string,
-) => {
-    const activeFileIds = new Set<number>();
-    allLocalFiles.forEach((file) => {
-        activeFileIds.add(file.id);
-    });
-
-    const remainingEmbeddings = allLocalEmbeddings.filter((embedding) =>
-        activeFileIds.has(embedding.fileID),
-    );
-    if (allLocalEmbeddings.length !== remainingEmbeddings.length) {
-        log.info(
-            `cleanupDeletedEmbeddings embeddingsCount: ${allLocalEmbeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`,
-        );
-        await localForage.setItem(tableName, remainingEmbeddings);
-    }
-};
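The two sync functions deleted above share a generic pull-diff shape: page through entries updated after a persisted cursor, process each page, advance the cursor to the newest `updatedAt` seen, and stop when a page comes back empty. Extracted as a sketch (names ours, not the codebase's):

```ts
interface DiffEntry {
    updatedAt: number;
}

// Page through a server-side diff until it is exhausted, advancing the
// sinceTime cursor to the newest entry seen so far after each page.
const pullDiff = async <T extends DiffEntry>(
    fetchPage: (sinceTime: number) => Promise<T[]>,
    onPage: (page: T[]) => Promise<void>,
    sinceTime = 0,
): Promise<number> => {
    while (true) {
        const page = await fetchPage(sinceTime);
        if (page.length === 0) return sinceTime;
        await onPage(page);
        sinceTime = page.reduce((max, { updatedAt }) => Math.max(max, updatedAt), sinceTime);
    }
};
```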
@@ -3,7 +3,6 @@ import DownloadManager from "@/new/photos/services/download";
 import { clearFeatureFlagSessionState } from "@/new/photos/services/feature-flags";
 import { logoutML, terminateMLWorker } from "@/new/photos/services/ml";
 import log from "@/next/log";
-import { clipService } from "services/clip-service";
 import exportService from "./export";
 
 /**

@@ -46,12 +45,6 @@ export const photosLogout = async () => {
         ignoreError("download", e);
     }
 
-    try {
-        await clipService.logout();
-    } catch (e) {
-        ignoreError("CLIP", e);
-    }
-
     // - Desktop
 
     const electron = globalThis.electron;
@@ -19,15 +19,11 @@ import {
 import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
 import { getUniqueFiles } from "utils/file";
 import { getFormattedDate } from "utils/search";
-import { clipService, computeClipMatchScore } from "./clip-service";
-import { localCLIPEmbeddings } from "./embeddingService";
 import { getLatestEntities } from "./entityService";
 import locationSearchService, { City } from "./locationSearchService";
 
 const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
 
-const CLIP_SCORE_THRESHOLD = 0.23;
-
 export const getDefaultOptions = async () => {
     return [
         // TODO-ML(MR): Skip this for now if indexing is disabled (eventually

@@ -373,30 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
 }
 
 const searchClip = async (
-    searchPhrase: string,
+    // eslint-disable-next-line @typescript-eslint/no-unused-vars
+    _searchPhrase: string,
 ): Promise<ClipSearchScores | undefined> => {
-    const textEmbedding =
-        await clipService.getTextEmbeddingIfAvailable(searchPhrase);
-    if (!textEmbedding) return undefined;
-
-    const imageEmbeddings = await localCLIPEmbeddings();
-    const clipSearchResult = new Map<number, number>(
-        (
-            await Promise.all(
-                imageEmbeddings.map(
-                    async (imageEmbedding): Promise<[number, number]> => [
-                        imageEmbedding.fileID,
-                        await computeClipMatchScore(
-                            imageEmbedding.embedding,
-                            textEmbedding,
-                        ),
-                    ],
-                ),
-            )
-        ).filter(([, score]) => score >= CLIP_SCORE_THRESHOLD),
-    );
-
-    return clipSearchResult;
+    // TODO-ML: clip-test
+    return undefined;
+    // const matches = await clipMatches(searchPhrase, ensureElectron());
+    // log.debug(() => ({ t: "clip-scores", matches }));
+    // return matches;
 };
 
 function convertSuggestionToSearchQuery(option: Suggestion): Search {
@@ -1,8 +1,6 @@
 import { fetchAndSaveFeatureFlagsIfNeeded } from "@/new/photos/services/feature-flags";
 import { triggerMLSync } from "@/new/photos/services/ml";
 import { isDesktop } from "@/next/app";
-import { clipService } from "services/clip-service";
-import { syncCLIPEmbeddings } from "services/embeddingService";
 import { syncEntities } from "services/entityService";
 import { syncMapEnabled } from "services/userService";
 

@@ -20,8 +18,6 @@ export const sync = async () => {
     await syncMapEnabled();
     fetchAndSaveFeatureFlagsIfNeeded();
     if (isDesktop) {
-        await syncCLIPEmbeddings();
         triggerMLSync();
-        void clipService.scheduleImageEmbeddingExtraction();
     }
 };
@@ -19,7 +19,6 @@ import { wait } from "@/utils/promise";
 import { getDedicatedCryptoWorker } from "@ente/shared/crypto";
 import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
 import { CustomError } from "@ente/shared/error";
-import { Events, eventBus } from "@ente/shared/events";
 import { Canceler } from "axios";
 import type { Remote } from "comlink";
 import isElectron from "is-electron";

@@ -617,27 +616,6 @@ class UploadManager {
                 const uploadItem =
                     uploadableItem.uploadItem ??
                     uploadableItem.livePhotoAssets.image;
-                try {
-                    let file: File | undefined;
-                    if (uploadItem) {
-                        if (uploadItem instanceof File) {
-                            file = uploadItem;
-                        } else if (
-                            typeof uploadItem == "string" ||
-                            Array.isArray(uploadItem)
-                        ) {
-                            // path from desktop, no file object
-                        } else {
-                            file = uploadItem.file;
-                        }
-                    }
-                    eventBus.emit(Events.FILE_UPLOADED, {
-                        enteFile: decryptedFile,
-                        localFile: file,
-                    });
-                } catch (e) {
-                    log.warn("Ignoring error in fileUploaded handlers", e);
-                }
                 if (
                     uploadItem &&
                     (uploadResult == UPLOAD_RESULT.UPLOADED ||
@@ -1,29 +0,0 @@
-import type { EmbeddingModel } from "@/new/photos/services/ml/embedding";
-
-export interface EncryptedEmbedding {
-    fileID: number;
-    /** @see {@link EmbeddingModel} */
-    model: string;
-    encryptedEmbedding: string;
-    decryptionHeader: string;
-    updatedAt: number;
-}
-
-export interface Embedding
-    extends Omit<
-        EncryptedEmbedding,
-        "encryptedEmbedding" | "decryptionHeader"
-    > {
-    embedding?: Float32Array;
-}
-
-export interface GetEmbeddingDiffResponse {
-    diff: EncryptedEmbedding[];
-}
-
-export interface PutEmbeddingRequest {
-    fileID: number;
-    model: EmbeddingModel;
-    encryptedEmbedding: string;
-    decryptionHeader: string;
-}
@@ -17,12 +17,21 @@
      * implies "ESNext.*"), "DOM.Iterable" (the ability to iterate over DOM
      * elements) is not a subset of "DOM" and needs to be listed out
      * explicitly.
-     *
-     * Note that we don't need to specify the `target` compilerOption, since
-     * tsc isn't actually generating (emitting) the JavaScript.
      */
     "lib": ["ESnext", "DOM", "DOM.Iterable", "WebWorker"],
 
+    /**
+     * We don't strictly need to specify the `target` compilerOption,
+     * since tsc isn't actually generating (emitting) the JavaScript.
+     * However, without a target TypeScript becomes very conservative, and
+     * complains about things like
+     *
+     * > Type 'Float32Array' can only be iterated through when using the
+     * > '--downlevelIteration' flag or with a '--target' of 'es2015' or
+     * > higher.
+     */
+    "target": "ES2015",
+
     /*
      * The module system to assume the generated JavaScript will use.
      *
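The compiler error quoted in the new comment is easy to reproduce: with no `target` (or a pre-ES2015 one), tsc rejects `for...of` over typed arrays, while with `"target": "ES2015"` this compiles as-is:

```ts
// Compiles only when target >= ES2015 (or with --downlevelIteration):
const embedding = new Float32Array([0.1, 0.2, 0.3]);
let norm = 0;
for (const v of embedding) norm += v * v;
console.log(Math.sqrt(norm));
```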
web/packages/new/photos/services/ml/clip.ts (new file, 208 lines)
@@ -0,0 +1,208 @@
+import type { EnteFile } from "@/new/photos/types/file";
+import type { Electron } from "@/next/types/ipc";
+import type { ImageBitmapAndData } from "./bitmap";
+import { clipIndexes } from "./db";
+import { pixelRGBBicubic } from "./image";
+import { cosineSimilarity, norm } from "./math";
+import type { MLWorkerElectron } from "./worker-electron";
+
+/**
+ * The version of the CLIP indexing pipeline implemented by the current client.
+ */
+export const clipIndexingVersion = 1;
+
+/**
+ * The CLIP embedding for a file (and some metadata).
+ *
+ * See {@link FaceIndex} for a similar structure with more comprehensive
+ * documentation.
+ *
+ * ---
+ *
+ * [Note: Natural language search using CLIP]
+ *
+ * CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
+ * (image, text) pairs. It can be thought of as two separate (but jointly
+ * trained) encoders - one for images, and one for text - that both map to the
+ * same embedding space.
+ *
+ * We use this for natural language search within the app:
+ *
+ * 1. Pre-compute an embedding for each image.
+ *
+ * 2. When the user searches, compute an embedding for the search term.
+ *
+ * 3. Use cosine similarity to find the image (embedding) closest to
+ *    the text (embedding).
+ *
+ * More details are in our [blog
+ * post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
+ * initial launch of this feature using the GGML runtime.
+ *
+ * Since the initial launch, we've switched over to another runtime,
+ * [ONNX](https://onnxruntime.ai) and have made other implementation changes,
+ * but the overall gist remains the same.
+ *
+ * Note that we don't train the neural network - we only use one of the publicly
+ * available pre-trained neural networks for inference. These neural networks
+ * are wholly defined by their connectivity and weights. ONNX, our ML runtime,
+ * loads these weights and instantiates a running network that we can use to
+ * compute the embeddings.
+ *
+ * Theoretically, the same CLIP model can be loaded by different frameworks /
+ * runtimes, but in practice each runtime has its own preferred format, and
+ * there are also quantization tradeoffs. So there is a specific model (a binary
+ * encoding of weights) tied to our current runtime that we use.
+ *
+ * To ensure that the embeddings, for the most part, can be shared, whenever
+ * possible we try to ensure that all the preprocessing steps, and the model
+ * itself, is the same across clients - web and mobile.
+ */
+export interface CLIPIndex {
+    /** The ID of the {@link EnteFile} whose index this is. */
+    fileID: number;
+    /** An integral version number of the indexing algorithm / pipeline. */
+    version: number;
+    /** The UA for the client which generated this embedding. */
+    client: string;
+    /**
+     * The CLIP embedding itself.
+     *
+     * This is an array of 512 floating point values that represent the
+     * embedding of the image in the same space where we'll embed the text so
+     * that both of them can be compared using a cosine distance.
+     */
+    embedding: number[];
+}
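Given the `CLIPIndex` shape above, the search side (whose implementation is truncated at the end of this diff) presumably ranks stored image embeddings against a text embedding. A hypothetical sketch using the helpers this module imports (`clipIndexes` from `./db`, `cosineSimilarity` from `./math`), whose exact signatures we are assuming:

```ts
// Score every locally indexed file against a text embedding. Both vectors
// are stored normalized, so cosineSimilarity is effectively a dot product.
const scoresForTextEmbedding = async (textEmbedding: number[]) =>
    new Map(
        (await clipIndexes()).map(
            ({ fileID, embedding }): [number, number] => [
                fileID,
                cosineSimilarity(embedding, textEmbedding),
            ],
        ),
    );
```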
|
||||
/**
|
||||
* Compute the CLIP embedding of a given file.
|
||||
*
|
||||
* This function is the entry point to the CLIP indexing pipeline. The file goes
|
||||
* through various stages:
|
||||
*
|
||||
* 1. Downloading the original if needed.
|
||||
* 2. Convert (if needed) and pre-process.
|
||||
* 3. Compute embeddings using ONNX/CLIP.
|
||||
*
|
||||
* Once all of it is done, it CLIP embedding (wrapped as a {@link CLIPIndex} so
|
||||
* that it can be saved locally and also uploaded to the user's remote storage
|
||||
* for use on their other devices).
|
||||
*
|
||||
* @param enteFile The {@link EnteFile} to index.
|
||||
*
|
||||
* @param uploadItem If we're called during the upload process, then this will
|
||||
* be set to the {@link UploadItem} that was uploaded. This way, we can directly
|
||||
* use the on-disk file instead of needing to download the original from remote.
|
||||
*
|
||||
* @param electron The {@link MLWorkerElectron} instance that allows us to call
|
||||
* our Node.js layer for various functionality.
|
||||
*
|
||||
* @param userAgent The UA of the client that is doing the indexing (us).
|
||||
*/
|
||||
export const indexCLIP = async (
|
||||
enteFile: EnteFile,
|
||||
image: ImageBitmapAndData,
|
||||
electron: MLWorkerElectron,
|
||||
userAgent: string,
|
||||
): Promise<CLIPIndex> => {
|
||||
const { data: imageData } = image;
|
||||
const fileID = enteFile.id;
|
||||
|
||||
return {
|
||||
fileID,
|
||||
version: clipIndexingVersion,
|
||||
client: userAgent,
|
||||
embedding: await computeEmbedding(imageData, electron),
|
||||
};
|
||||
};
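
For reference, a sketch of the shape of the value this returns; the numbers
here are illustrative, not real output:

// An illustrative CLIPIndex (values made up for exposition):
// {
//     fileID: 10001,
//     version: 1,
//     client: "<user agent of the indexing client>",
//     embedding: [0.012, -0.034, /* ...512 values in total */],
// }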

const computeEmbedding = async (
    imageData: ImageData,
    electron: MLWorkerElectron,
): Promise<number[]> => {
    const clipInput = convertToCLIPInput(imageData);
    return normalized(await electron.computeCLIPImageEmbedding(clipInput));
};

/**
 * Convert {@link imageData} into the format that the CLIP model expects.
 */
const convertToCLIPInput = (imageData: ImageData) => {
    const requiredWidth = 224;
    const requiredHeight = 224;

    const mean = [0.48145466, 0.4578275, 0.40821073] as const;
    const std = [0.26862954, 0.26130258, 0.27577711] as const;

    const { width, height, data: pixelData } = imageData;

    // Maintain aspect ratio.
    const scale = Math.max(requiredWidth / width, requiredHeight / height);

    const scaledWidth = Math.round(width * scale);
    const scaledHeight = Math.round(height * scale);
    const widthOffset = Math.max(0, scaledWidth - requiredWidth) / 2;
    const heightOffset = Math.max(0, scaledHeight - requiredHeight) / 2;

    const clipInput = new Float32Array(3 * requiredWidth * requiredHeight);

    // Populate the Float32Array with normalized pixel values.
    let pi = 0;
    const cOffsetG = requiredHeight * requiredWidth; // ChannelOffsetGreen
    const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
    for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
        for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
            const { r, g, b } = pixelRGBBicubic(
                w / scale,
                h / scale,
                pixelData,
                width,
                height,
            );
            clipInput[pi] = (r / 255.0 - mean[0]) / std[0];
            clipInput[pi + cOffsetG] = (g / 255.0 - mean[1]) / std[1];
            clipInput[pi + cOffsetB] = (b / 255.0 - mean[2]) / std[2];
            pi++;
        }
    }
    return clipInput;
};
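
A quick sanity check of the normalization above (hand arithmetic, not part of
the diff): a pure white pixel should land at roughly (1.93, 2.07, 2.15) across
the R, G, B planes of the output tensor.

// e.g. the red channel of a white pixel (r = 255):
// (255 / 255 - 0.48145466) / 0.26862954 ≈ 1.93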

const normalized = (embedding: Float32Array) => {
    const nums = Array.from(embedding);
    const n = norm(nums);
    return nums.map((v) => v / n);
};

/**
 * Use CLIP to perform a natural language search over image embeddings.
 *
 * @param searchPhrase The text entered by the user in the search box.
 *
 * @param electron The {@link Electron} instance to use to communicate with the
 * native code running in our desktop app (the embedding happens in the native
 * layer).
 *
 * It returns a list of files that should be shown in the search results. The
 * actual return type is a map from fileIDs to the scores they got (higher is
 * better). This map will only contain entries whose score was above our
 * minimum threshold.
 *
 * The result can also be `undefined`, which indicates that the download for
 * the ML model is still in progress (trying again later should succeed).
 */
export const clipMatches = async (
    searchPhrase: string,
    electron: Electron,
): Promise<Map<number, number> | undefined> => {
    const t = await electron.computeCLIPTextEmbeddingIfAvailable(searchPhrase);
    if (!t) return undefined;

    const textEmbedding = normalized(t);
    const items = (await clipIndexes()).map(
        ({ fileID, embedding }) =>
            [fileID, cosineSimilarity(embedding, textEmbedding)] as const,
    );
    return new Map(items.filter(([, score]) => score >= 0.23));
};
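
To illustrate the flow end to end, a minimal (hypothetical) usage sketch; the
search phrase and the sorting of results are our own additions, not part of
this diff:

const matches = await clipMatches("sunset at the beach", electron);
if (matches === undefined) {
    // The text encoder model is still being downloaded; retry later.
} else {
    // File IDs to show in the results, best score first.
    const resultFileIDs = [...matches.entries()]
        .sort(([, a], [, b]) => b - a)
        .map(([fileID]) => fileID);
}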

@@ -3,7 +3,7 @@ import { ensure } from "@/utils/ensure";
import type { EnteFile } from "../../types/file";
import { renderableImageBitmap } from "./bitmap";
import { type Box, type FaceIndex } from "./face";
import { clamp } from "./image";
import { clamp } from "./math";

/**
 * Regenerate and locally save face crops for faces in the given file.

@@ -1,37 +1,48 @@
import log from "@/next/log";
import localForage from "@ente/shared/storage/localForage";
import { deleteDB, openDB, type DBSchema } from "idb";
import type { CLIPIndex } from "./clip";
import type { EmbeddingModel } from "./embedding";
import type { FaceIndex } from "./face";

/**
 * Face DB schema.
 * ML DB schema.
 *
 * There "face" database is made of two object stores:
 * The "ML" database is made of three object stores:
 *
 * - "file-status": Contains {@link FileStatus} objects, one for each
 *   {@link EnteFile} that the ML subsystem knows about. Periodically (and when
 *   required), this is synced with the list of files that the current client
 *   knows about locally.
 *
 * - "face-index": Contains {@link FaceIndex} objects, either indexed locally or
 *   fetched from remote storage.
 *
 * - "file-status": Contains {@link FileStatus} objects, one for each
 *   {@link EnteFile} that the current client knows about.
 * - "clip-index": Contains {@link CLIPIndex} objects, either indexed locally or
 *   fetched from remote storage.
 *
 * Both the stores are keyed by {@link fileID}, and are expected to contain the
 * exact same set of {@link fileID}s. The face-index can be thought of as the
 * "original" indexing result, whilst file-status bookkeeps information about
 * the indexing process (whether or not a file needs indexing, or if there were
 * errors doing so).
 * All the stores are keyed by {@link fileID}. The "file-status" contains
 * book-keeping about the indexing process (whether or not a file needs
 * indexing, or if there were errors doing so), while the other stores contain
 * the actual indexing results.
 *
 * In tandem, these serve as the underlying storage for the functions exposed by
 * this file.
 */
interface FaceDBSchema extends DBSchema {
    "face-index": {
        key: number;
        value: FaceIndex;
    };
interface MLDBSchema extends DBSchema {
    "file-status": {
        key: number;
        value: FileStatus;
        indexes: { status: FileStatus["status"] };
    };
    "face-index": {
        key: number;
        value: FaceIndex;
    };
    "clip-index": {
        key: number;
        value: CLIPIndex;
    };
}

interface FileStatus {

@@ -44,94 +55,128 @@ interface FileStatus {
     * to be indexed.
     *
     * - "indexed" - We have a corresponding entry for this file in the
     *   "face-index" object (either indexed locally or fetched from remote).
     *   "face-index" _and_ "clip-index" object stores (either indexed locally
     *   or fetched from remote).
     *
     * - "failed" - Indexing was attempted but failed.
     *
     * We also have a (IndexedDB) "index" on this field to allow us to
     * efficiently select or count {@link fileIDs} that fall into various
     * buckets.
     * There can arise situations in which a file has some, but not all, of
     * its indexes. e.g. it may have a "face-index", but its "clip-index"
     * might not have been computed yet (or might have failed). In such cases,
     * the status of the file will be "indexable": it transitions to "indexed"
     * only after all indexes have been computed or fetched.
     *
     * If you haven't already heard the word "index" to the point of zoning
     * out, we also have an (IndexedDB) "index" on the status field to allow
     * us to efficiently select or count {@link fileIDs} that fall into
     * various buckets.
     */
    status: "indexable" | "indexed" | "failed";
    /**
     * A list of embeddings that we still need to compute for the file.
     *
     * This is guaranteed to be empty if status is "indexed", and will have at
     * least one entry otherwise.
     */
    pending: EmbeddingModel[];
    /**
     * The number of times attempts to index this file have failed.
     *
     * It counts failures across all index types.
     *
     * This is guaranteed to be `0` for files with status "indexed".
     */
    failureCount: number;
}
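
To make the lifecycle concrete, a sketch of how an entry might evolve once
both pipelines are enabled (values illustrative):

// Freshly enrolled file:
//   { fileID: 42, status: "indexable", pending: ["file-ml-clip-face", "onnx-clip"], failureCount: 0 }
// After saveFaceIndex:
//   { fileID: 42, status: "indexable", pending: ["onnx-clip"], failureCount: 0 }
// After saveCLIPIndex too:
//   { fileID: 42, status: "indexed", pending: [], failureCount: 0 }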

/**
 * A lazily-created, cached promise for face DB.
 * A lazily-created, cached promise for ML DB.
 *
 * See: [Note: Caching IDB instances in separate execution contexts].
 */
let _faceDB: ReturnType<typeof openFaceDB> | undefined;
let _mlDB: ReturnType<typeof openMLDB> | undefined;

const openFaceDB = async () => {
const openMLDB = async () => {
    deleteLegacyDB();

    const db = await openDB<FaceDBSchema>("face", 1, {
    // TODO-ML: "face" => "ml", v2 => v1
    const db = await openDB<MLDBSchema>("face", 2, {
        upgrade(db, oldVersion, newVersion) {
            log.info(`Upgrading face DB ${oldVersion} => ${newVersion}`);
            log.info(`Upgrading ML DB ${oldVersion} => ${newVersion}`);
            if (oldVersion < 1) {
                db.createObjectStore("face-index", { keyPath: "fileID" });
                db.createObjectStore("file-status", {
                    keyPath: "fileID",
                }).createIndex("status", "status");
                db.createObjectStore("face-index", { keyPath: "fileID" });
            }
            if (oldVersion < 2) {
                db.createObjectStore("clip-index", { keyPath: "fileID" });
            }
        },
        blocking() {
            log.info(
                "Another client is attempting to open a new version of face DB",
                "Another client is attempting to open a new version of ML DB",
            );
            db.close();
            _faceDB = undefined;
            _mlDB = undefined;
        },
        blocked() {
            log.warn(
                "Waiting for an existing client to close their connection so that we can update the face DB version",
                "Waiting for an existing client to close their connection so that we can update the ML DB version",
            );
        },
        terminated() {
            log.warn("Our connection to face DB was unexpectedly terminated");
            _faceDB = undefined;
            log.warn("Our connection to ML DB was unexpectedly terminated");
            _mlDB = undefined;
        },
    });
    return db;
};

const deleteLegacyDB = () => {
    // Delete the legacy face DB.
    // Delete the legacy face DB v1.
    //
    // This code was added June 2024 (v1.7.1-rc) and can be removed once clients
    // have migrated over.
    void deleteDB("mldata");

    // Delete the legacy CLIP (mostly) related keys from LocalForage.
    //
    // This code was added July 2024 (v1.7.2-rc) and can be removed once
    // sufficient clients have migrated over (tag: Migration).
    void Promise.all([
        localForage.removeItem("embeddings"),
        localForage.removeItem("embedding_sync_time"),
        localForage.removeItem("embeddings_v2"),
        localForage.removeItem("file_embeddings"),
        localForage.removeItem("onnx-clip-embedding_sync_time"),
        localForage.removeItem("file-ml-clip-face-embedding_sync_time"),
    ]);
};

/**
 * @returns a lazily created, cached connection to the face DB.
 * @returns a lazily created, cached connection to the ML DB.
 */
const faceDB = () => (_faceDB ??= openFaceDB());
const mlDB = () => (_mlDB ??= openMLDB());

/**
 * Clear any data stored in the face DB.
 * Clear any data stored in the ML DB.
 *
 * This is meant to be called during logout in the main thread.
 * This is meant to be called during logout on the main thread.
 */
export const clearFaceDB = async () => {
export const clearMLDB = async () => {
    deleteLegacyDB();

    try {
        if (_faceDB) (await _faceDB).close();
        if (_mlDB) (await _mlDB).close();
    } catch (e) {
        log.warn("Ignoring error when trying to close face DB", e);
        log.warn("Ignoring error when trying to close ML DB", e);
    }
    _faceDB = undefined;
    _mlDB = undefined;

    return deleteDB("face", {
        blocked() {
            log.warn(
                "Waiting for an existing client to close their connection so that we can delete the face DB",
                "Waiting for an existing client to close their connection so that we can delete the ML DB",
            );
        },
    });

@@ -141,23 +186,79 @@ export const clearFaceDB = async () => {
 * Save the given {@link faceIndex} locally.
 *
 * @param faceIndex A {@link FaceIndex} representing the faces that we detected
 * (and their corresponding embeddings) in some file.
 * (and their corresponding embeddings) in a particular file.
 *
 * This function adds a new entry, overwriting any existing ones (No merging is
 * performed, the existing entry is unconditionally overwritten).
 * This function adds a new entry for the face index, overwriting any existing
 * ones (No merging is performed, the existing entry is unconditionally
 * overwritten). The file status is also updated to remove the face embedding
 * from the pending embeddings. If there are no other pending embeddings, the
 * status changes to "indexed".
 */
export const saveFaceIndex = async (faceIndex: FaceIndex) => {
    const db = await faceDB();
    const tx = db.transaction(["face-index", "file-status"], "readwrite");
    const indexStore = tx.objectStore("face-index");
    const { fileID } = faceIndex;

    const db = await mlDB();
    const tx = db.transaction(["file-status", "face-index"], "readwrite");
    const statusStore = tx.objectStore("file-status");
    const indexStore = tx.objectStore("face-index");

    const fileStatus =
        (await statusStore.get(IDBKeyRange.only(fileID))) ??
        newFileStatus(fileID);
    fileStatus.pending = fileStatus.pending.filter(
        (v) => v != "file-ml-clip-face",
    );
    if (fileStatus.pending.length == 0) fileStatus.status = "indexed";

    await Promise.all([
        statusStore.put(fileStatus),
        indexStore.put(faceIndex),
        statusStore.put({
            fileID: faceIndex.fileID,
            status: "indexed",
            failureCount: 0,
        }),
        tx.done,
    ]);
};

/**
 * Return a new object suitable for use as the initial value of the entry for a
 * file in the file status store.
 */
const newFileStatus = (fileID: number): FileStatus => ({
    fileID,
    status: "indexable",
    // TODO-ML: clip-test
    // pending: ["file-ml-clip-face", "onnx-clip"],
    pending: ["file-ml-clip-face"],
    failureCount: 0,
});

/**
 * Save the given {@link clipIndex} locally.
 *
 * @param clipIndex A {@link CLIPIndex} containing the CLIP embedding for a
 * particular file.
 *
 * This function adds a new entry for the CLIP index, overwriting any existing
 * ones (No merging is performed, the existing entry is unconditionally
 * overwritten). The file status is also updated to remove the CLIP embedding
 * from the pending embeddings. If there are no other pending embeddings, the
 * status changes to "indexed".
 */
export const saveCLIPIndex = async (clipIndex: CLIPIndex) => {
    const { fileID } = clipIndex;

    const db = await mlDB();
    const tx = db.transaction(["file-status", "clip-index"], "readwrite");
    const statusStore = tx.objectStore("file-status");
    const indexStore = tx.objectStore("clip-index");

    const fileStatus =
        (await statusStore.get(IDBKeyRange.only(fileID))) ??
        newFileStatus(fileID);
    fileStatus.pending = fileStatus.pending.filter((v) => v != "onnx-clip");
    if (fileStatus.pending.length == 0) fileStatus.status = "indexed";

    await Promise.all([
        statusStore.put(fileStatus),
        indexStore.put(clipIndex),
        tx.done,
    ]);
};
@@ -166,13 +267,21 @@ export const saveFaceIndex = async (faceIndex: FaceIndex) => {
 * Return the {@link FaceIndex}, if any, for {@link fileID}.
 */
export const faceIndex = async (fileID: number) => {
    const db = await faceDB();
    const db = await mlDB();
    return db.get("face-index", fileID);
};

/**
 * Record the existence of a file so that entities in the face indexing universe
 * know about it (e.g. can index it if it is new and it needs indexing).
 * Return all CLIP indexes present locally.
 */
export const clipIndexes = async () => {
    const db = await mlDB();
    return await db.getAll("clip-index");
};

/**
 * Record the existence of a file so that entities in the ML universe know about
 * it (e.g. can index it if it is new and it needs indexing).
 *
 * @param fileID The ID of an {@link EnteFile}.
 *

@@ -182,48 +291,45 @@ export const faceIndex = async (fileID: number) => {
 * unperturbed.
 */
export const addFileEntry = async (fileID: number) => {
    const db = await faceDB();
    const db = await mlDB();
    const tx = db.transaction("file-status", "readwrite");
    if ((await tx.store.getKey(fileID)) === undefined) {
        await tx.store.put({
            fileID,
            status: "indexable",
            failureCount: 0,
        });
    }
    if ((await tx.store.getKey(fileID)) === undefined)
        await tx.store.put(newFileStatus(fileID));
    return tx.done;
};

/**
 * Update entries in the face DB to align with the state of local files outside
 * face DB.
 * Update entries in ML DB to align with the state of local files outside ML DB.
 *
 * @param localFileIDs IDs of all the files that the client is aware of filtered
 * to only keep the files that the user owns and the formats that can be indexed
 * by our current face indexing pipeline.
 * @param localFileIDs IDs of all the files that the client is aware of,
 * filtered to only keep the files that the user owns and the formats that can
 * be indexed by our current indexing pipelines.
 *
 * @param localTrashFilesIDs IDs of all the files in trash.
 *
 * This function then updates the state of file entries in face DB to the be in
 * sync with these provided local file IDS.
 * This function then updates the state of file entries in ML DB to be in sync
 * with these provided local file IDs.
 *
 * - Files that are present locally but are not yet in face DB get a fresh entry
 * - Files that are present locally but are not yet in ML DB get a fresh entry
 *   in face DB (and are marked as indexable).
 *
 * - Files that are not present locally (nor are in trash) but still exist in
 *   face DB are removed from face DB (including their face index, if any).
 * - Files that are not present locally (nor are in trash) but still exist in ML
 *   DB are removed from ML DB (including any indexes).
 *
 * - Files that are not present locally but are in the trash are retained in
 *   face DB if their status is "indexed" (otherwise they too are removed). This
 *   is prevent churn (re-indexing) if the user moves some files to trash but
 *   then later restores them before they get permanently deleted.
 * - Files that are not present locally but are in the trash are retained in ML
 *   DB if their status is "indexed"; otherwise they too are removed. This
 *   special case is to prevent churn (re-indexing) if the user moves some files
 *   to trash but then later restores them before they get permanently deleted.
 */
export const updateAssumingLocalFiles = async (
    localFileIDs: number[],
    localTrashFilesIDs: number[],
) => {
    const db = await faceDB();
    const tx = db.transaction(["face-index", "file-status"], "readwrite");
    const db = await mlDB();
    const tx = db.transaction(
        ["file-status", "face-index", "clip-index"],
        "readwrite",
    );
    const fdbFileIDs = await tx.objectStore("file-status").getAllKeys();
    const fdbIndexedFileIDs = await tx
        .objectStore("file-status")

@@ -250,16 +356,13 @@ export const updateAssumingLocalFiles = async (
    await Promise.all(
        [
            newFileIDs.map((id) =>
                tx.objectStore("file-status").put({
                    fileID: id,
                    status: "indexable",
                    failureCount: 0,
                }),
                tx.objectStore("file-status").put(newFileStatus(id)),
            ),
            removedFileIDs.map((id) =>
                tx.objectStore("file-status").delete(id),
            ),
            removedFileIDs.map((id) => tx.objectStore("face-index").delete(id)),
            removedFileIDs.map((id) => tx.objectStore("clip-index").delete(id)),
            tx.done,
        ].flat(),
    );
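
A small worked example of the sync rules described above (IDs illustrative,
not part of the diff):

// localFileIDs = [1, 2], localTrashFilesIDs = [3]
// DB before: 2 ("indexed"), 3 ("indexed"), 4 ("indexable")
// DB after:  1 (fresh "indexable" entry), 2 (kept), 3 (kept; in trash but
//            already indexed), 4 (removed, along with any of its indexes)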

@@ -268,11 +371,12 @@ export const updateAssumingLocalFiles = async (
/**
 * Return the count of files that can be, and that have been, indexed.
 *
 * These counts are mutually exclusive. The total number of files that fall
 * within the purview of the indexer is thus indexable + indexed.
 * These counts are mutually exclusive. Thus the total number of files that
 * fall within the purview of the indexer will be indexable + indexed (if we
 * are ignoring the "failed" ones).
 */
export const indexableAndIndexedCounts = async () => {
    const db = await faceDB();
    const db = await mlDB();
    const tx = db.transaction("file-status", "readwrite");
    const indexableCount = await tx.store
        .index("status")

@@ -287,14 +391,14 @@ export const indexableAndIndexedCounts = async () => {
 * Return a list of fileIDs that need to be indexed.
 *
 * This list is from the universe of the file IDs that the face DB knows about
 * (can use {@link addFileEntry} to inform it about new files). From this
 * (we can use {@link addFileEntry} to inform it about new files). From this
 * universe, we filter out the fileIDs whose corresponding files have already
 * been indexed, or which should be ignored.
 *
 * @param count Limit the result to up to {@link count} items.
 */
export const indexableFileIDs = async (count?: number) => {
    const db = await faceDB();
    const db = await mlDB();
    const tx = db.transaction("file-status", "readonly");
    return tx.store
        .index("status")

@@ -309,15 +413,15 @@ export const indexableFileIDs = async (count?: number) => {
 * If an entry does not exist yet for the given file, then a new one is created
 * and its failure count is set to 1. Otherwise the failure count of the
 * existing entry is incremented.
 *
 * This count is across all different types of indexing (face, CLIP) that happen
 * on the file.
 */
export const markIndexingFailed = async (fileID: number) => {
    const db = await faceDB();
    const db = await mlDB();
    const tx = db.transaction("file-status", "readwrite");
    const failureCount = ((await tx.store.get(fileID))?.failureCount ?? 0) + 1;
    await tx.store.put({
        fileID,
        status: "failed",
        failureCount,
    });
    return tx.done;
    const fileStatus = (await tx.store.get(fileID)) ?? newFileStatus(fileID);
    fileStatus.status = "failed";
    fileStatus.failureCount = fileStatus.failureCount + 1;
    await Promise.all([tx.store.put(fileStatus), tx.done]);
};

@@ -12,8 +12,9 @@ import { getKV, setKV } from "@/next/kv";
import log from "@/next/log";
import { apiURL } from "@/next/origins";
import { z } from "zod";
import { saveFaceIndex } from "./db";
import { type FaceIndex, faceIndexingVersion } from "./face";
import { clipIndexingVersion, type CLIPIndex } from "./clip";
import { saveCLIPIndex, saveFaceIndex } from "./db";
import { faceIndexingVersion, type FaceIndex } from "./face";

/**
 * The embeddings that we (the current client) know how to handle.

@@ -48,7 +49,7 @@ import { type FaceIndex, faceIndexingVersion } from "./face";
 * "model" (i.e "type") field to create a new universe of embeddings.
 */
export type EmbeddingModel =
    | "onnx-clip" /* CLIP (text) embeddings */
    | "onnx-clip" /* CLIP embeddings */
    | "file-ml-clip-face" /* Face embeddings */;

const RemoteEmbedding = z.object({

@@ -397,3 +398,61 @@ const FaceIndex = z
 */
export const putFaceIndex = async (enteFile: EnteFile, faceIndex: FaceIndex) =>
    putEmbedding(enteFile, "file-ml-clip-face", JSON.stringify(faceIndex));

// MARK: - CLIP

/**
 * Fetch new or updated CLIP embeddings from remote and save them locally.
 *
 * See {@link pullFaceEmbeddings} for a sibling function with more comprehensive
 * documentation.
 *
 * @returns true if at least one embedding was pulled, false otherwise.
 */
export const pullCLIPEmbeddings = () =>
    pullEmbeddings("onnx-clip", (jsonString: string) =>
        // eslint-disable-next-line @typescript-eslint/prefer-ts-expect-error, @typescript-eslint/ban-ts-comment
        // @ts-ignore TODO: There is no error here, but this file is imported by
        // one of our packages that doesn't have strict mode enabled yet,
        // causing a spurious error to be emitted in that context.
        saveCLIPIndexIfNewer(CLIPIndex.parse(JSON.parse(jsonString))),
    );

/**
 * Save the given {@link clipIndex} locally if it is newer than the one we have.
 *
 * This is a variant of {@link saveCLIPIndex} that performs version checking as
 * described in [Note: Handling versioning of embeddings].
 */
const saveCLIPIndexIfNewer = async (index: CLIPIndex) => {
    const version = index.version;
    if (version < clipIndexingVersion) {
        log.info(
            `Ignoring remote CLIP index with version ${version} older than what our indexer can produce (${clipIndexingVersion})`,
        );
        return;
    }
    return saveCLIPIndex(index);
};

/**
 * Zod schemas for the {@link CLIPIndex} types.
 *
 * See: [Note: Duplicated between Zod schemas and TS type]
 */
const CLIPIndex = z
    .object({
        fileID: z.number(),
        version: z.number(),
        client: z.string(),
        embedding: z.array(z.number()),
    })
    // Retain fields we might not (currently) understand.
    .passthrough();
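
For instance, a remote payload would be validated along these lines (payload
illustrative; unknown fields survive the parse thanks to passthrough):

const parsed = CLIPIndex.parse(
    // A real embedding would have 512 entries.
    JSON.parse('{"fileID":1,"version":1,"client":"web","embedding":[0.1,0.2]}'),
);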

/**
 * Save the CLIP index for the given {@link enteFile} on remote so that other
 * clients can directly pull it instead of needing to reindex.
 */
export const putCLIPIndex = async (enteFile: EnteFile, clipIndex: CLIPIndex) =>
    putEmbedding(enteFile, "onnx-clip", JSON.stringify(clipIndex));

@@ -21,11 +21,11 @@ import {
import type { ImageBitmapAndData } from "./bitmap";
import { saveFaceCrops } from "./crop";
import {
    clamp,
    grayscaleIntMatrixFromNormalized2List,
    pixelRGBBilinear,
    warpAffineFloat32List,
} from "./image";
import { clamp } from "./math";
import type { MLWorkerElectron } from "./worker-electron";

/**

@@ -196,24 +196,26 @@ export interface Box {
/**
 * Index faces in the given file.
 *
 * This function is the entry point to the indexing pipeline. The file goes
 * This function is the entry point to the face indexing pipeline. The file goes
 * through various stages:
 *
 * 1. Downloading the original if needed.
 * 2. Detect faces using ONNX/YOLO.
 * 3. Align the face rectangles, compute blur.
 * 4. Compute embeddings for the detected face (crops).
 * 4. Compute embeddings using ONNX/MFNT for the detected face (crop).
 *
 * Once all of it is done, it returns the face rectangles and embeddings so that
 * they can be saved locally for offline use, and encrypts and uploads them to
 * the user's remote storage so that their other devices can download them
 * instead of needing to reindex.
 * they can be saved locally (for offline use), and also uploaded to the user's
 * remote storage so that their other devices can download them instead of
 * needing to reindex.
 *
 * As an optimization, it also saves the face crops of the detected faces to the
 * local cache (they can be regenerated independently too by using
 * {@link regenerateFaceCrops}).
 *
 * @param enteFile The {@link EnteFile} to index.
 *
 * @param uploadItem If we're called during the upload process, then this will
 * be set to the {@link UploadItem} that was uploaded. This way, we can directly
 * use the on-disk file instead of needing to download the original from remote.
 * @param image The file's contents.
 *
 * @param electron The {@link MLWorkerElectron} instance that allows us to call
 * our Node.js layer for various functionality.

@@ -3,12 +3,7 @@

import { ensure } from "@/utils/ensure";
import { Matrix, inverse } from "ml-matrix";

/**
 * Clamp {@link value} to between {@link min} and {@link max}, inclusive.
 */
export const clamp = (value: number, min: number, max: number) =>
    Math.min(max, Math.max(min, value));
import { clamp } from "./math";

/**
 * Returns the pixel value (RGB) at the given coordinates ({@link fx},

@@ -80,7 +75,7 @@ const pixelRGBA = (
 * Returns the pixel value (RGB) at the given coordinates ({@link fx},
 * {@link fy}) using bicubic interpolation.
 */
const pixelRGBBicubic = (
export const pixelRGBBicubic = (
    fx: number,
    fy: number,
    imageData: Uint8ClampedArray,

@@ -16,7 +16,7 @@ import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { proxy } from "comlink";
import type { UploadItem } from "../upload/types";
import { regenerateFaceCrops } from "./crop";
import { clearFaceDB, faceIndex, indexableAndIndexedCounts } from "./db";
import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
import { MLWorker } from "./worker";

/**

@@ -47,6 +47,7 @@ const createComlinkWorker = async () => {
        appVersion: electron.appVersion,
        detectFaces: electron.detectFaces,
        computeFaceEmbeddings: electron.computeFaceEmbeddings,
        computeCLIPImageEmbedding: electron.computeCLIPImageEmbedding,
    };

    const cw = new ComlinkWorker<typeof MLWorker>(

@@ -85,9 +86,9 @@ export const logoutML = async () => {
    // `terminateMLWorker` is conceptually also part of this, but for the
    // reasons mentioned in [Note: Caching IDB instances in separate execution
    // contexts], it gets called first in the logout sequence, and then this
    // `logoutML` gets called at a later point in time.
    // function (`logoutML`) gets called at a later point in time.
    _isMLEnabled = false;
    await clearFaceDB();
    await clearMLDB();
};

/**

web/packages/new/photos/services/ml/math.ts (new file, 43 lines)
@@ -0,0 +1,43 @@
/**
 * Clamp {@link value} to between {@link min} and {@link max}, inclusive.
 */
export const clamp = (value: number, min: number, max: number) =>
    Math.min(max, Math.max(min, value));

/**
 * Return the dot-product of two vectors.
 *
 * The dot product is the sum of the products of the corresponding elements of
 * the two given vectors.
 *
 * Precondition: The two vectors must be of the same length.
 */
export const dotProduct = (v1: number[], v2: number[]) => {
    if (v1.length != v2.length)
        throw new Error(`Length mismatch ${v1.length} ${v2.length}`);
    let d = 0;
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    for (let i = 0; i < v1.length; i++) d += v1[i]! * v2[i]!;
    return d;
};

/**
 * Return the L2-norm ("magnitude") of the given vector.
 *
 * L2-norm is the sqrt of the sum of the squares of the components of the
 * vector. It can also be thought of as the sqrt of the dot product of the
 * vector with itself.
 */
export const norm = (v: number[]) =>
    Math.sqrt(v.reduce((a, x) => a + x * x, 0));

/**
 * Return the cosine similarity of the two given vectors.
 *
 * The result is a value between [-1, 1] (inclusive), indicating how similar
 * (1), orthogonal (0) and dissimilar (-1) the two vectors are in direction.
 *
 * Precondition: The two vectors must be of the same length.
 */
export const cosineSimilarity = (v1: number[], v2: number[]) =>
    dotProduct(v1, v2) / (norm(v1) * norm(v2));
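
These are small enough to verify by hand; a few spot checks (plain arithmetic,
not part of the diff):

dotProduct([1, 2], [3, 4]); // 1*3 + 2*4 = 11
norm([3, 4]); // sqrt(9 + 16) = 5
cosineSimilarity([1, 0], [0, 1]); // 0 (orthogonal)
cosineSimilarity([1, 1], [2, 2]); // 1 (same direction)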

@@ -10,4 +10,5 @@ export interface MLWorkerElectron {
    appVersion: () => Promise<string>;
    detectFaces: (input: Float32Array) => Promise<Float32Array>;
    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;
}

@@ -11,15 +11,17 @@ import { expose } from "comlink";
import downloadManager from "../download";
import { getAllLocalFiles, getLocalTrashedFiles } from "../files";
import type { UploadItem } from "../upload/types";
import { imageBitmapAndData } from "./bitmap";
import { imageBitmapAndData, type ImageBitmapAndData } from "./bitmap";
import { indexCLIP, type CLIPIndex } from "./clip";
import {
    indexableFileIDs,
    markIndexingFailed,
    saveCLIPIndex,
    saveFaceIndex,
    updateAssumingLocalFiles,
} from "./db";
import { pullFaceEmbeddings, putFaceIndex } from "./embedding";
import { type FaceIndex, indexFaces } from "./face";
import { pullFaceEmbeddings, putCLIPIndex, putFaceIndex } from "./embedding";
import { indexFaces, type FaceIndex } from "./face";
import type { MLWorkerElectron } from "./worker-electron";

const idleDurationStart = 5; /* 5 seconds */

@@ -239,7 +241,24 @@ expose(MLWorker);
 * Return true if at least one embedding was pulled.
 */
const pull = async () => {
    return pullFaceEmbeddings();
    const res = await Promise.allSettled([
        pullFaceEmbeddings(),
        // TODO-ML: clip-test
        // pullCLIPEmbeddings(),
    ]);
    for (const r of res) {
        switch (r.status) {
            case "fulfilled":
                // Return true if any pulled something.
                if (r.value) return true;
                break;
            case "rejected":
                // Throw if any failed.
                throw r.reason;
        }
    }
    // Return false if neither pulled anything.
    return false;
};

/**

@@ -339,6 +358,32 @@ const index = async (
    const startTime = Date.now();

    const image = await imageBitmapAndData(enteFile, uploadItem, electron);
    const res = await Promise.allSettled([
        _indexFace(f, enteFile, image, electron, userAgent),
        // TODO-ML: clip-test
        // _indexCLIP(f, enteFile, image, electron, userAgent),
    ]);
    image.bitmap.close();

    const msg: string[] = [];
    for (const r of res) {
        if (r.status == "rejected") throw r.reason;
        else msg.push(r.value);
    }

    log.debug(() => {
        const ms = Date.now() - startTime;
        return `Indexed ${msg.join(" and ")} in ${f} (${ms} ms)`;
    });
};

const _indexFace = async (
    f: string,
    enteFile: EnteFile,
    image: ImageBitmapAndData,
    electron: MLWorkerElectron,
    userAgent: string,
) => {
    let faceIndex: FaceIndex;
    try {
        faceIndex = await indexFaces(enteFile, image, electron, userAgent);

@@ -346,8 +391,6 @@ const index = async (
        log.error(`Failed to index faces in ${f}`, e);
        await markIndexingFailed(enteFile.id);
        throw e;
    } finally {
        image.bitmap.close();
    }

    // [Note: Transient and permanent indexing failures]

@@ -389,11 +432,37 @@ const index = async (
        throw e;
    }

    log.debug(() => {
        const nf = faceIndex.faceEmbedding.faces.length;
        const ms = Date.now() - startTime;
        return `Indexed ${nf} faces in ${f} (${ms} ms)`;
    });

    return faceIndex;
    // A message for debug printing.
    return `${faceIndex.faceEmbedding.faces.length} faces`;
};

// TODO-ML: clip-test export
export const _indexCLIP = async (
    f: string,
    enteFile: EnteFile,
    image: ImageBitmapAndData,
    electron: MLWorkerElectron,
    userAgent: string,
) => {
    let clipIndex: CLIPIndex;
    try {
        clipIndex = await indexCLIP(enteFile, image, electron, userAgent);
    } catch (e) {
        log.error(`Failed to index CLIP in ${f}`, e);
        await markIndexingFailed(enteFile.id);
        throw e;
    }

    // See: [Note: Transient and permanent indexing failures]
    try {
        await putCLIPIndex(enteFile, clipIndex);
        await saveCLIPIndex(clipIndex);
    } catch (e) {
        log.error(`Failed to put/save CLIP index for ${f}`, e);
        if (isHTTP4xxError(e)) await markIndexingFailed(enteFile.id);
        throw e;
    }

    // A message for debug printing.
    return "clip";
};

@@ -38,8 +38,8 @@ export const apiURL = async (path: string) => (await apiOrigin()) + path;
export const customAPIOrigin = async () => {
    let origin = await getKV("apiOrigin");
    if (!origin && !inWorker()) {
        // TODO: Migration of apiOrigin from local storage to indexed DB
        // Remove me after a bit (27 June 2024).
        // TODO: Migration of apiOrigin from local storage to indexed DB. Added
        // 27 June 2024, 1.7.2-rc. Remove me after a bit (tag: Migration).
        const legacyOrigin = localStorage.getItem("apiOrigin");
        if (legacyOrigin !== null) {
            origin = legacyOrigin;

@@ -337,15 +337,15 @@ export interface Electron {
    /**
     * Return a CLIP embedding of the given image.
     *
     * See: [Note: CLIP based magic search]
     * See: [Note: Natural language search using CLIP]
     *
     * @param jpegImageData The raw bytes of the image encoded as an JPEG.
     * The input is an opaque float32 array representing the image. The layout
     * and exact encoding of the input are specific to our implementation and
     * the ML model (CLIP) we use.
     *
     * @returns A CLIP embedding.
     * @returns A CLIP embedding (an array of 512 floating point values).
     */
    computeCLIPImageEmbedding: (
        jpegImageData: Uint8Array,
    ) => Promise<Float32Array>;
    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;

    /**
     * Return a CLIP embedding of the given image if we already have the model

@@ -361,7 +361,7 @@ export interface Electron {
     * embeddings are used as part of deducing user initiated search results,
     * and we don't want to block that interaction on a large network request.
     *
     * See: [Note: CLIP based magic search]
     * See: [Note: Natural language search using CLIP]
     *
     * @param text The string whose embedding we want to compute.
     *

@@ -1,7 +1,6 @@
import { ensureElectron } from "@/next/electron";
import log, { logToDisk } from "@/next/log";
import { expose, wrap, type Remote } from "comlink";
import { ensureLocalUser } from "../local-user";

/**
 * A minimal wrapper for a web {@link Worker}, proxying a class of type T.

@@ -62,13 +61,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
 *
 * Not all workers need access to all these functions, and this can indeed be
 * done in a more fine-grained, per-worker, manner if needed. For now, since it
 * is a motley bunch, we just inject them all to all workers.
 * is just a couple, we just inject them all to all workers.
 */
const workerBridge = {
    // Needed by all workers (likely, not necessarily).
    // Needed by all workers (likely, but not necessarily).
    logToDisk,
    // Needed by MLWorker.
    getAuthToken: () => ensureLocalUser().token,
    // Needed by ML worker.
    convertToJPEG: (imageData: Uint8Array) =>
        ensureElectron().convertToJPEG(imageData),
};

@@ -70,8 +70,6 @@ export const CustomError = {
    TWO_FACTOR_ENABLED: "two factor enabled",
    CLIENT_ERROR: "client error",
    ServerError: "server error",
    FILE_NOT_FOUND: "file not found",
    UNSUPPORTED_PLATFORM: "Unsupported platform",
    UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
    URL_ALREADY_SET: "url already set",
    FILE_CONVERSION_FAILED: "file conversion failed",

@@ -1,11 +1,11 @@
import { EventEmitter } from "eventemitter3";

// TODO: Remove me

// When registering event handlers,
// handle errors to avoid unhandled rejection or propagation to emit call

export enum Events {
    LOGOUT = "logout",
    FILE_UPLOADED = "fileUploaded",
    LOCAL_FILES_UPDATED = "localFilesUpdated",
}

@@ -2477,11 +2477,6 @@ eventemitter3@^4.0.7:
  resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-4.0.7.tgz#2de9b68f6528d5644ef5c59526a1b4a07306169f"
  integrity sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==

eventemitter3@^5.0.1:
  version "5.0.1"
  resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-5.0.1.tgz#53f5ffd0a492ac800721bb42c66b841de96423c4"
  integrity sha512-GWkBvjiSZK87ELrYOSESUYeVIc9mvLLf/nXalMOS5dYrgZq9o5OVkbZAVM06CVxYsCwH9BDZFPlQTlPA1j4ahA==

exifr@^7.1.3:
  version "7.1.3"
  resolved "https://registry.yarnpkg.com/exifr/-/exifr-7.1.3.tgz#f6218012c36dbb7d843222011b27f065fddbab6f"

@@ -3759,19 +3754,6 @@ p-locate@^5.0.0:
  dependencies:
    p-limit "^3.0.2"

p-queue@^7.1.0:
  version "7.4.1"
  resolved "https://registry.yarnpkg.com/p-queue/-/p-queue-7.4.1.tgz#7f86f853048beca8272abdbb7cec1ed2afc0f265"
  integrity sha512-vRpMXmIkYF2/1hLBKisKeVYJZ8S2tZ0zEAmIJgdVKP2nq0nh4qCdf8bgw+ZgKrkh71AOCaqzwbJJk1WtdcF3VA==
  dependencies:
    eventemitter3 "^5.0.1"
    p-timeout "^5.0.2"

p-timeout@^5.0.2:
  version "5.1.0"
  resolved "https://registry.yarnpkg.com/p-timeout/-/p-timeout-5.1.0.tgz#b3c691cf4415138ce2d9cfe071dba11f0fee085b"
  integrity sha512-auFDyzzzGZZZdHz3BtET9VEz0SE/uMEAx7uWfGPucfzEwwe/xH0iVeZibQmANYE/hp9T2+UUZT5m+BKyrDp3Ew==

pako@~1.0.2:
  version "1.0.11"
  resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.11.tgz#6c9599d340d54dfd3946380252a35705a6b992bf"