[desktop] Use MobileCLIP for magic search (#2649)
@@ -202,19 +202,19 @@ const createInferenceSession = async (modelPath: string) => {
 };
 
 const cachedCLIPImageSession = makeCachedInferenceSession(
-    "clip-image-vit-32-float32.onnx",
-    351468764 /* 335.2 MB */,
+    "mobileclip_s2_image.onnx",
+    143061211 /* 143 MB */,
 );
 
 /**
  * Compute CLIP embeddings for an image.
  *
- * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ * The embeddings are computed using ONNX runtime, with MobileCLIP as the model.
  */
 export const computeCLIPImageEmbedding = async (input: Float32Array) => {
     const session = await cachedCLIPImageSession();
     const feeds = {
-        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
+        input: new ort.Tensor("float32", input, [1, 3, 256, 256]),
     };
     const t = Date.now();
     const results = await session.run(feeds);
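For reference, a minimal standalone sketch of the call pattern in the hunk above. It assumes the onnxruntime-node API and an already created ort.InferenceSession; the helper name embedImage and the output lookup are illustrative, not taken from this commit.

import * as ort from "onnxruntime-node";

// Editorial sketch, not part of the commit: feed a CHW float32 tensor of
// shape [1, 3, 256, 256] (the new MobileCLIP image input size) under the
// "input" feed name used above, and read the embedding back from the
// session's first output.
const embedImage = async (session: ort.InferenceSession, rgbChw: Float32Array) => {
    const feeds = { input: new ort.Tensor("float32", rgbChw, [1, 3, 256, 256]) };
    const results = await session.run(feeds);
    return results[session.outputNames[0]].data as Float32Array;
};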
@@ -224,8 +224,8 @@ export const computeCLIPImageEmbedding = async (input: Float32Array) => {
 };
 
 const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
+    "mobileclip_s2_text_int32.onnx",
+    253895600 /* 253 MB */,
 );
 
 let _tokenizer: Tokenizer | undefined;
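The "_int32" suffix of the new text model suggests it takes tokenized text as int32 ids. A hedged sketch of what the corresponding call might look like; the tokenization step, sequence shape, and input name below are assumptions, not taken from this commit.

import * as ort from "onnxruntime-node";

// Hypothetical sketch, not part of the commit: encode already-tokenized text
// with the MobileCLIP text model by feeding the token ids as an int32 tensor
// of shape [1, n].
const embedText = async (session: ort.InferenceSession, tokenIds: number[]) => {
    const feeds = {
        input: new ort.Tensor("int32", Int32Array.from(tokenIds), [1, tokenIds.length]),
    };
    const results = await session.run(feeds);
    return results[session.outputNames[0]].data as Float32Array;
};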
@@ -237,7 +237,7 @@ const getTokenizer = () => {
 /**
  * Compute CLIP embeddings for a text snippet.
  *
- * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ * The embeddings are computed using ONNX runtime, with MobileCLIP as the model.
  */
 export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
     const sessionOrSkip = await Promise.race([
@@ -270,7 +270,7 @@ export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
 
 const cachedFaceDetectionSession = makeCachedInferenceSession(
     "yolov5s_face_640_640_dynamic.onnx",
-    30762872 /* 29.3 MB */,
+    30762872 /* 29 MB */,
 );
 
 /**

@@ -1,7 +1,7 @@
 import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
-import { pixelRGBBicubic } from "./image";
+import { pixelRGBBilinear } from "./image";
 import { dotProduct, norm } from "./math";
 import type { CLIPMatches } from "./worker-types";
 
@@ -39,8 +39,9 @@ export const clipIndexingVersion = 1;
  * initial launch of this feature using the GGML runtime.
  *
  * Since the initial launch, we've switched over to another runtime,
- * [ONNX](https://onnxruntime.ai) and have made other implementation changes,
- * but the overall gist remains the same.
+ * [ONNX](https://onnxruntime.ai), started using Apple's
+ * [MobileCLIP](https://github.com/apple/ml-mobileclip/) as the model and have
+ * made other implementation changes, but the overall gist remains the same.
  *
  * Note that we don't train the neural network - we only use one of the publicly
  * available pre-trained neural networks for inference. These neural networks
@@ -117,14 +118,10 @@ const computeEmbedding = async (
 };
 
 /**
- * Convert {@link imageData} into the format that the CLIP model expects.
+ * Convert {@link imageData} into the format that the MobileCLIP model expects.
  */
 const convertToCLIPInput = (imageData: ImageData) => {
-    const requiredWidth = 224;
-    const requiredHeight = 224;
-
-    const mean = [0.48145466, 0.4578275, 0.40821073] as const;
-    const std = [0.26862954, 0.26130258, 0.27577711] as const;
+    const [requiredWidth, requiredHeight] = [256, 256];
 
     const { width, height, data: pixelData } = imageData;
 
@@ -144,16 +141,16 @@ const convertToCLIPInput = (imageData: ImageData) => {
     const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
     for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
         for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
-            const { r, g, b } = pixelRGBBicubic(
+            const { r, g, b } = pixelRGBBilinear(
                 w / scale,
                 h / scale,
                 pixelData,
                 width,
                 height,
             );
-            clipInput[pi] = (r / 255.0 - mean[0]) / std[0];
-            clipInput[pi + cOffsetG] = (g / 255.0 - mean[1]) / std[1];
-            clipInput[pi + cOffsetB] = (b / 255.0 - mean[2]) / std[2];
+            clipInput[pi] = r / 255.0;
+            clipInput[pi + cOffsetG] = g / 255.0;
+            clipInput[pi + cOffsetB] = b / 255.0;
             pi++;
         }
     }
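Distilled from the two hunks above, a self-contained sketch of the new preprocessing: bilinear sampling into a 256×256 CHW buffer with plain [0, 1] scaling and no per-channel mean/std normalization. Here samplePixel stands in for pixelRGBBilinear, and the scaling and crop offsets of the real function are elided.

// Editorial sketch, not part of the commit: pack RGB into a CHW Float32Array
// the way the MobileCLIP image encoder now expects it.
const toCLIPInputCHW = (
    samplePixel: (x: number, y: number) => { r: number; g: number; b: number },
    requiredWidth = 256,
    requiredHeight = 256,
) => {
    const out = new Float32Array(3 * requiredHeight * requiredWidth);
    const cOffsetG = requiredHeight * requiredWidth; // channel offset, green
    const cOffsetB = 2 * requiredHeight * requiredWidth; // channel offset, blue
    let pi = 0;
    for (let h = 0; h < requiredHeight; h++) {
        for (let w = 0; w < requiredWidth; w++) {
            const { r, g, b } = samplePixel(w, h);
            out[pi] = r / 255.0;
            out[pi + cOffsetG] = g / 255.0;
            out[pi + cOffsetB] = b / 255.0;
            pi++;
        }
    }
    return out;
};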
@@ -189,5 +186,5 @@ export const clipMatches = async (
         // This code is on the hot path, so these optimizations help.
         [fileID, dotProduct(embedding, textEmbedding)] as const,
     );
-    return new Map(items.filter(([, score]) => score >= 0.23));
+    return new Map(items.filter(([, score]) => score >= 0.2));
 };
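To make the effect of the new cutoff concrete, a sketch of the matching step under the assumption that both image and text embeddings are already L2-normalized, so the dot product is the cosine similarity. The helper names here are illustrative, not the ones in ./math.

// Editorial sketch, not part of the commit: keep only files whose image
// embedding scores at least `threshold` (0.2 for MobileCLIP, down from the
// 0.23 used with the previous CLIP model) against the query text embedding.
const matchScores = (
    imageEmbeddings: Map<number, Float32Array>, // fileID -> normalized embedding
    textEmbedding: Float32Array, // normalized query embedding
    threshold = 0.2,
): Map<number, number> => {
    const dot = (a: Float32Array, b: Float32Array) => {
        let s = 0;
        for (let i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    };
    const items = [...imageEmbeddings].map(
        ([fileID, embedding]) => [fileID, dot(embedding, textEmbedding)] as const,
    );
    return new Map(items.filter(([, score]) => score >= threshold));
};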