commit 56fe538f07 (parent 942231c521)
Author: Manav Rathi
Date:   2024-07-09 14:55:27 +05:30

5 changed files with 16 additions and 49 deletions


@@ -186,10 +186,8 @@ export const attachIPCHandlers = () => {
     // - ML

-    ipcMain.handle(
-        "computeCLIPImageEmbedding",
-        (_, jpegImageData: Uint8Array) =>
-            computeCLIPImageEmbedding(jpegImageData),
+    ipcMain.handle("computeCLIPImageEmbedding", (_, input: Float32Array) =>
+        computeCLIPImageEmbedding(input),
     );

     ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>


@@ -1,23 +1,12 @@
-// TODO: These arise from the array indexing in the pre-processing code. Isolate
-// once that code settles down to its final place (currently duplicated across
-// web and desktop).
-/* eslint-disable @typescript-eslint/no-non-null-assertion */
 /**
  * @file Compute CLIP embeddings for images and text.
  *
  * The embeddings are computed using ONNX runtime, with CLIP as the model.
  *
  * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
 import Tokenizer from "clip-bpe-js";
-import jpeg from "jpeg-js";
-import fs from "node:fs/promises";
 import * as ort from "onnxruntime-node";
 import log from "../log";
-import { writeStream } from "../stream";
 import { ensure, wait } from "../utils/common";
-import { deleteTempFile, makeTempFilePath } from "../utils/temp";
 import { makeCachedInferenceSession } from "./ml";

 const cachedCLIPImageSession = makeCachedInferenceSession(
@@ -25,36 +14,18 @@ const cachedCLIPImageSession = makeCachedInferenceSession(
     351468764 /* 335.2 MB */,
 );

-export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => {
-    const tempFilePath = await makeTempFilePath();
-    const imageStream = new Response(jpegImageData.buffer).body;
-    await writeStream(tempFilePath, ensure(imageStream));
-    try {
-        return await clipImageEmbedding_(tempFilePath);
-    } finally {
-        await deleteTempFile(tempFilePath);
-    }
-};
-
-const clipImageEmbedding_ = async (jpegFilePath: string) => {
+export const computeCLIPImageEmbedding = async (input: Float32Array) => {
     const session = await cachedCLIPImageSession();
-    const t1 = Date.now();
-    const rgbData = await getRGBData(jpegFilePath);
+    const t = Date.now();
     const feeds = {
-        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
     };
-    const t2 = Date.now();
     const results = await session.run(feeds);
-    log.debug(
-        () =>
-            `ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
-    );
+    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
     /* Need these model specific casts to type the result */
-    const imageEmbedding = ensure(results.output).data as Float32Array;
-    return normalizeEmbedding(imageEmbedding);
+    return ensure(results.output).data as Float32Array;
 };

 const cachedCLIPTextSession = makeCachedInferenceSession(
     "clip-text-vit-32-uint8.onnx",
     64173509 /* 61.2 MB */,
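The desktop side thus no longer decodes JPEG data or prepares the tensor; callers must now supply an already pre-processed [1, 3, 224, 224] float32 tensor. A minimal sketch of such caller-side pre-processing, assuming decoded RGBA pixels, nearest-neighbour scaling, and the standard CLIP normalization constants (the web app's actual implementation may differ on all three counts):

```ts
// Hypothetical helper: convert decoded RGBA pixels into the [1, 3, 224, 224]
// CHW float32 tensor that computeCLIPImageEmbedding now expects.
const toCLIPImageInput = (
    rgba: Uint8Array, // 4 bytes per pixel, row-major
    width: number,
    height: number,
): Float32Array => {
    const size = 224;
    // Standard CLIP per-channel normalization constants (an assumption; the
    // constants expected by the particular ONNX model in use may differ).
    const mean = [0.48145466, 0.4578275, 0.40821073];
    const std = [0.26862954, 0.26130258, 0.27577711];
    const out = new Float32Array(3 * size * size);
    for (let y = 0; y < size; y++) {
        for (let x = 0; x < size; x++) {
            // Nearest-neighbour sample from the source image.
            const sx = Math.min(width - 1, Math.round((x * width) / size));
            const sy = Math.min(height - 1, Math.round((y * height) / size));
            const px = (sy * width + sx) * 4;
            for (let c = 0; c < 3; c++) {
                // Planar (CHW) layout: all R values, then all G, then all B.
                out[c * size * size + y * size + x] =
                    (rgba[px + c]! / 255 - mean[c]!) / std[c]!;
            }
        }
    }
    return out;
};
```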
@@ -95,6 +66,5 @@ export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
         () =>
             `ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const textEmbedding = ensure(results.output).data as Float32Array;
-    return normalizeEmbedding(textEmbedding);
+    return ensure(results.output).data as Float32Array;
 };
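Note that both embedding functions now return the raw model output; normalizing the embedding becomes the caller's job, as the web-side hunk further below suggests. A minimal sketch of such an L2 normalizeEmbedding, assuming unit-length normalization is what the web side's function does:

```ts
// Hypothetical L2 normalization; the web app's actual normalizeEmbedding may
// differ, but the desktop side no longer performs this step itself.
const normalizeEmbedding = (embedding: Float32Array): Float32Array => {
    let norm = 0;
    for (const v of embedding) norm += v * v;
    norm = Math.sqrt(norm);
    // Scale each component so the embedding has unit length, which makes
    // cosine similarity between embeddings a plain dot product.
    return embedding.map((v) => v / norm);
};
```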


@@ -163,8 +163,8 @@ const ffmpegExec = (

 // - ML

-const computeCLIPImageEmbedding = (jpegImageData: Uint8Array) =>
-    ipcRenderer.invoke("computeCLIPImageEmbedding", jpegImageData);
+const computeCLIPImageEmbedding = (input: Float32Array) =>
+    ipcRenderer.invoke("computeCLIPImageEmbedding", input);

 const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
     ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);


@@ -402,7 +402,7 @@ const searchClip = async (
 // };

 // getTextEmbeddingIfAvailable = async (text: string) => {
-//     return ensureElectron().computeCLIPTextEmbeddingIfAvailable(text);
+//     return normalizeEmbedding(ensureElectron().computeCLIPTextEmbeddingIfAvailable(text));
 // };

 // export const computeClipMatchScore = async (


@@ -339,14 +339,13 @@ export interface Electron {
      *
      * See: [Note: Natural language search using CLIP]
      *
-     * The input is a opaque binary data whose internal structure is specific to
-     * our implementation and the ML model (CLIP) we use.
+     * The input is an opaque float32 array representing the image. The layout
+     * and exact encoding of the input are specific to our implementation and
+     * the ML model (CLIP) we use.
      *
-     * @returns A CLIP embedding.
+     * @returns A CLIP embedding (an array of 512 floating point values).
      */
-    computeCLIPImageEmbedding: (
-        jpegImageData: Uint8Array,
-    ) => Promise<Float32Array>;
+    computeCLIPImageEmbedding: (input: Float32Array) => Promise<Float32Array>;

     /**
      * Return a CLIP embedding of the given image if we already have the model
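Putting the pieces together, a hypothetical renderer-side caller of the updated interface, reusing the toCLIPImageInput and normalizeEmbedding sketches above (electron being the Electron object exposed by the preload script):

```ts
const clipImageEmbedding = async (
    electron: Electron,
    rgba: Uint8Array,
    width: number,
    height: number,
): Promise<Float32Array> => {
    // Pre-process on the caller's side, hand the raw tensor data to the
    // desktop process, then normalize whatever comes back.
    const input = toCLIPImageInput(rgba, width, height);
    const embedding = await electron.computeCLIPImageEmbedding(input);
    return normalizeEmbedding(embedding); // 512 unit-norm floats
};
```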