[desktop] Use MobileCLIP for magic search (#2649)

Manav Rathi, 2024-08-10 14:08:37 +05:30 (committed by GitHub)
2 changed files with 19 additions and 22 deletions


@@ -202,19 +202,19 @@ const createInferenceSession = async (modelPath: string) => {
 };
 
 const cachedCLIPImageSession = makeCachedInferenceSession(
-    "clip-image-vit-32-float32.onnx",
-    351468764 /* 335.2 MB */,
+    "mobileclip_s2_image.onnx",
+    143061211 /* 143 MB */,
 );
 
 /**
  * Compute CLIP embeddings for an image.
  *
- * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ * The embeddings are computed using ONNX runtime, with MobileCLIP as the model.
  */
 export const computeCLIPImageEmbedding = async (input: Float32Array) => {
     const session = await cachedCLIPImageSession();
     const feeds = {
-        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
+        input: new ort.Tensor("float32", input, [1, 3, 256, 256]),
     };
     const t = Date.now();
     const results = await session.run(feeds);
@@ -224,8 +224,8 @@ export const computeCLIPImageEmbedding = async (input: Float32Array) => {
 };
 
 const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
+    "mobileclip_s2_text_int32.onnx",
+    253895600 /* 253 MB */,
 );
 
 let _tokenizer: Tokenizer | undefined;
@@ -237,7 +237,7 @@ const getTokenizer = () => {
 /**
  * Compute CLIP embeddings for a text snippet.
  *
- * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ * The embeddings are computed using ONNX runtime, with MobileCLIP as the model.
  */
 export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
     const sessionOrSkip = await Promise.race([
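The new text model's name ends in _int32, suggesting it takes int32 token ids as input. A rough sketch of how the text session is presumably fed; the tokenizer's encode() method, the 77-token context length, and the "input" feed name are assumptions based on standard CLIP tokenization, not taken from this diff:

// Sketch only. getTokenizer() and cachedCLIPTextSession exist in this file;
// encode() and the [1, 77] shape are hypothetical stand-ins.
const tokens = getTokenizer().encode(text); // assumed to pad/truncate to 77 ids
const session = await cachedCLIPTextSession();
const feeds = {
    input: new ort.Tensor("int32", Int32Array.from(tokens), [1, 77]),
};
const results = await session.run(feeds);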
@@ -270,7 +270,7 @@ export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
 const cachedFaceDetectionSession = makeCachedInferenceSession(
     "yolov5s_face_640_640_dynamic.onnx",
-    30762872 /* 29.3 MB */,
+    30762872 /* 29 MB */,
 );
 
 /**

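Taken together, the exported surface of this file is unchanged; only the model files and the expected image size differ. A minimal usage sketch, assuming both functions resolve to plain embedding vectors (their return types are not visible in this diff):

// Sketch only: assumes computeCLIPImageEmbedding resolves to the embedding,
// and computeCLIPTextEmbeddingIfAvailable resolves to undefined when the
// text model has not been downloaded yet.
const imageInput = new Float32Array(1 * 3 * 256 * 256); // CHW pixels in [0, 1]
const imageEmbedding = await computeCLIPImageEmbedding(imageInput);
const textEmbedding = await computeCLIPTextEmbeddingIfAvailable("sunset at the beach");
if (textEmbedding) {
    // Both vectors live in the same MobileCLIP embedding space, so their
    // similarity is directly comparable (see the clipMatches change below).
}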

@@ -1,7 +1,7 @@
 import type { ElectronMLWorker } from "@/base/types/ipc";
 import type { ImageBitmapAndData } from "./blob";
 import { clipIndexes } from "./db";
-import { pixelRGBBicubic } from "./image";
+import { pixelRGBBilinear } from "./image";
 import { dotProduct, norm } from "./math";
 import type { CLIPMatches } from "./worker-types";
@@ -39,8 +39,9 @@ export const clipIndexingVersion = 1;
  * initial launch of this feature using the GGML runtime.
  *
  * Since the initial launch, we've switched over to another runtime,
- * [ONNX](https://onnxruntime.ai) and have made other implementation changes,
- * but the overall gist remains the same.
+ * [ONNX](https://onnxruntime.ai), started using Apple's
+ * [MobileCLIP](https://github.com/apple/ml-mobileclip/) as the model and have
+ * made other implementation changes, but the overall gist remains the same.
  *
  * Note that we don't train the neural network - we only use one of the publicly
  * available pre-trained neural networks for inference. These neural networks
@@ -117,14 +118,10 @@ const computeEmbedding = async (
 };
 
 /**
- * Convert {@link imageData} into the format that the CLIP model expects.
+ * Convert {@link imageData} into the format that the MobileCLIP model expects.
  */
 const convertToCLIPInput = (imageData: ImageData) => {
-    const requiredWidth = 224;
-    const requiredHeight = 224;
-    const mean = [0.48145466, 0.4578275, 0.40821073] as const;
-    const std = [0.26862954, 0.26130258, 0.27577711] as const;
+    const [requiredWidth, requiredHeight] = [256, 256];
 
     const { width, height, data: pixelData } = imageData;
@@ -144,16 +141,16 @@ const convertToCLIPInput = (imageData: ImageData) => {
     const cOffsetB = 2 * requiredHeight * requiredWidth; // ChannelOffsetBlue
     for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
         for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
-            const { r, g, b } = pixelRGBBicubic(
+            const { r, g, b } = pixelRGBBilinear(
                 w / scale,
                 h / scale,
                 pixelData,
                 width,
                 height,
             );
-            clipInput[pi] = (r / 255.0 - mean[0]) / std[0];
-            clipInput[pi + cOffsetG] = (g / 255.0 - mean[1]) / std[1];
-            clipInput[pi + cOffsetB] = (b / 255.0 - mean[2]) / std[2];
+            clipInput[pi] = r / 255.0;
+            clipInput[pi + cOffsetG] = g / 255.0;
+            clipInput[pi + cOffsetB] = b / 255.0;
             pi++;
         }
     }
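Putting the convertToCLIPInput fragments together, the preprocessing after this change is: bilinearly resample a centered crop of the image to 256 x 256 and write the pixels as planar (CHW) RGB floats in [0, 1], with no mean/std normalization. A self-contained sketch; the function name and the scale/offset arithmetic are reconstructed from context, not copied from the file:

// Sketch: build the [1, 3, 256, 256] float32 input for the MobileCLIP image
// encoder from an ImageData. pixelRGBBilinear(x, y, data, w, h) is assumed to
// bilinearly sample an { r, g, b } value out of the RGBA pixel data.
const toMobileCLIPInput = (imageData: ImageData): Float32Array => {
    const [requiredWidth, requiredHeight] = [256, 256];
    const { width, height, data: pixelData } = imageData;

    // Scale so the smaller side becomes 256, then center-crop the larger one.
    const scale = Math.max(requiredWidth / width, requiredHeight / height);
    const scaledWidth = Math.round(width * scale);
    const scaledHeight = Math.round(height * scale);
    const widthOffset = Math.max(0, (scaledWidth - requiredWidth) / 2);
    const heightOffset = Math.max(0, (scaledHeight - requiredHeight) / 2);

    // Planar (CHW) layout: all R values, then all G, then all B.
    const clipInput = new Float32Array(3 * requiredWidth * requiredHeight);
    const cOffsetG = requiredHeight * requiredWidth;
    const cOffsetB = 2 * requiredHeight * requiredWidth;

    let pi = 0;
    for (let h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
        for (let w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
            const { r, g, b } = pixelRGBBilinear(w / scale, h / scale, pixelData, width, height);
            // Plain [0, 1] pixels; the old model's ImageNet-style
            // normalization with mean and std is gone.
            clipInput[pi] = r / 255.0;
            clipInput[pi + cOffsetG] = g / 255.0;
            clipInput[pi + cOffsetB] = b / 255.0;
            pi++;
        }
    }
    return clipInput;
};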
@@ -189,5 +186,5 @@ export const clipMatches = async (
             // This code is on the hot path, so these optimizations help.
             [fileID, dotProduct(embedding, textEmbedding)] as const,
     );
-    return new Map(items.filter(([, score]) => score >= 0.23));
+    return new Map(items.filter(([, score]) => score >= 0.2));
 };
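Finally, the match cutoff drops from 0.23 to 0.2 alongside the model change. Assuming the stored image embeddings and the text embedding are unit length (the file imports norm from "./math", presumably for that), the dot product above is the cosine similarity and selection reduces to a threshold test; imageEmbeddings below is a hypothetical stand-in for what clipIndexes() provides:

// Sketch, assuming unit-length embeddings so that dotProduct is the cosine
// similarity; imageEmbeddings maps fileID -> embedding.
const matchesFor = (
    textEmbedding: Float32Array,
    imageEmbeddings: Map<number, Float32Array>,
): Map<number, number> => {
    const matches = new Map<number, number>();
    for (const [fileID, embedding] of imageEmbeddings) {
        const score = dotProduct(embedding, textEmbedding);
        if (score >= 0.2) matches.set(fileID, score); // new MobileCLIP cutoff
    }
    return matches;
};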