Split on the main/utility axis
@@ -45,8 +45,9 @@ import { logout } from "./services/logout";
 import {
     computeCLIPImageEmbedding,
     computeCLIPTextEmbeddingIfAvailable,
-} from "./services/ml-clip";
-import { computeFaceEmbeddings, detectFaces } from "./services/ml-face";
+    computeFaceEmbeddings,
+    detectFaces,
+} from "./services/ml-utility";
 import {
     encryptionKey,
     lastShownChangelogVersion,
desktop/src/main/services/ml-clip.ts (deleted)
@@ -1,68 +0,0 @@
/**
 * @file Compute CLIP embeddings for images and text.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */

import Tokenizer from "clip-bpe-js";
import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure, wait } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";

const cachedCLIPImageSession = makeCachedInferenceSession(
    "clip-image-vit-32-float32.onnx",
    351468764 /* 335.2 MB */,
);

export const computeCLIPImageEmbedding = async (input: Float32Array) => {
    const session = await cachedCLIPImageSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to type the result */
    return ensure(results.output).data as Float32Array;
};

const cachedCLIPTextSession = makeCachedInferenceSession(
    "clip-text-vit-32-uint8.onnx",
    64173509 /* 61.2 MB */,
);

let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
    if (!_tokenizer) _tokenizer = new Tokenizer();
    return _tokenizer;
};

export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
    const sessionOrSkip = await Promise.race([
        cachedCLIPTextSession(),
        // Wait a tick to let the session promise resolve the first time this
        // code runs on each app start (if the model is already downloaded).
        wait(0).then(() => 1),
    ]);

    // Don't wait for the download to complete.
    if (typeof sessionOrSkip == "number") {
        log.info(
            "Ignoring CLIP text embedding request because model download is pending",
        );
        return undefined;
    }

    const session = sessionOrSkip;
    const t = Date.now();
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };

    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
    return ensure(results.output).data as Float32Array;
};
desktop/src/main/services/ml-face.ts (deleted)
@@ -1,53 +0,0 @@
/**
 * @file Various face recognition related tasks.
 *
 * - Face detection with the YOLO model.
 * - Face embedding with the MobileFaceNet model.
 *
 * The runtime used is ONNX.
 */

import * as ort from "onnxruntime-node";
import log from "../log";
import { ensure } from "../utils/common";
import { makeCachedInferenceSession } from "./ml";

const cachedFaceDetectionSession = makeCachedInferenceSession(
    "yolov5s_face_640_640_dynamic.onnx",
    30762872 /* 29.3 MB */,
);

export const detectFaces = async (input: Float32Array) => {
    const session = await cachedFaceDetectionSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
    return ensure(results.output).data;
};

const cachedFaceEmbeddingSession = makeCachedInferenceSession(
    "mobilefacenet_opset15.onnx",
    5286998 /* 5 MB */,
);

export const computeFaceEmbeddings = async (input: Float32Array) => {
    // Dimension of each face (alias)
    const mobileFaceNetFaceSize = 112;
    // Smaller alias
    const z = mobileFaceNetFaceSize;
    // Number of faces in the batch
    const n = Math.round(input.length / (z * z * 3));
    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

    const session = await cachedFaceEmbeddingSession();
    const t = Date.now();
    const feeds = { img_inputs: inputTensor };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to extract and type the result */
    return (results.embeddings as unknown as Record<string, unknown>)
        .cpuData as Float32Array;
};
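For concreteness, the flat input that computeFaceEmbeddings expects packs n face crops of 112 × 112 RGB floats, so its length must be a multiple of 112 × 112 × 3 = 37632. A small sketch of that arithmetic for a hypothetical two-face batch (illustrative only, not part of this commit):

// Illustrative arithmetic for the computeFaceEmbeddings input layout.
const faceSize = 112; // MobileFaceNet face crop side length
const floatsPerFace = faceSize * faceSize * 3; // 37632 floats per RGB crop
const batch = new Float32Array(2 * floatsPerFace); // hypothetical 2-face batch
const n = Math.round(batch.length / floatsPerFace);
console.log(n); // 2 — the leading dimension of the [n, 112, 112, 3] tensor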
desktop/src/main/services/ml-utility.ts (new file, 235 lines)
@@ -0,0 +1,235 @@
/**
 * @file ML related tasks. This code runs in a utility process.
 *
 * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
 * for various tasks are not shipped with the app but are downloaded on demand.
 */

import Tokenizer from "clip-bpe-js";
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import log from "../log";
import { writeStream } from "../stream";
import { ensure, wait } from "../utils/common";

/**
 * Return a function that can be used to trigger a download of the specified
 * model, and the creation of an ONNX inference session initialized using it.
 *
 * Multiple parallel calls to the returned function are fine: it ensures that
 * the model will be downloaded and the session created using it only once.
 * All pending calls to it meanwhile will just await on the same promise.
 *
 * And once the promise is resolved, the created ONNX inference session will be
 * cached, so subsequent calls to the returned function will just reuse the same
 * session.
 *
 * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
 * actively trigger a download until the returned function is called.
 *
 * @param modelName The name of the model to download.
 *
 * @param modelByteSize The size in bytes that we expect the model to have. If
 * the size of the downloaded model does not match the expected size, then we
 * will redownload it.
 *
 * @returns A function; calling that function returns a promise to an ONNX
 * session.
 */
const makeCachedInferenceSession = (
    modelName: string,
    modelByteSize: number,
) => {
    let session: Promise<ort.InferenceSession> | undefined;

    const download = () =>
        modelPathDownloadingIfNeeded(modelName, modelByteSize);

    const createSession = (modelPath: string) =>
        createInferenceSession(modelPath);

    const cachedInferenceSession = () => {
        if (!session) session = download().then(createSession);
        return session;
    };

    return cachedInferenceSession;
};
/**
 * Download the model named {@link modelName} if we don't already have it.
 *
 * Also verify that the size of the model we get matches
 * {@link expectedByteSize} (if not, redownload it).
 *
 * @returns the path to the model on the local machine.
 */
const modelPathDownloadingIfNeeded = async (
    modelName: string,
    expectedByteSize: number,
) => {
    const modelPath = modelSavePath(modelName);

    if (!existsSync(modelPath)) {
        log.info(`Model ${modelName} not found, downloading`);
        await downloadModel(modelPath, modelName);
    } else {
        const size = (await fs.stat(modelPath)).size;
        if (size !== expectedByteSize) {
            log.error(
                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
            );
            await downloadModel(modelPath, modelName);
        }
    }

    return modelPath;
};

/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
    path.join(app.getPath("userData"), "models", modelName);

const downloadModel = async (saveLocation: string, name: string) => {
    // `mkdir -p` the directory where we want to save the model.
    const saveDir = path.dirname(saveLocation);
    await fs.mkdir(saveDir, { recursive: true });
    // Download.
    log.info(`Downloading ML model ${name}`);
    const url = `https://models.ente.io/${name}`;
    const res = await net.fetch(url);
    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
    const body = res.body;
    if (!body) throw new Error(`Received a null response for ${url}`);
    // Save.
    await writeStream(saveLocation, body);
    log.info(`Downloaded ML model ${name}`);
};

/**
 * Create an ONNX {@link InferenceSession} with some defaults.
 */
const createInferenceSession = async (modelPath: string) => {
    return await ort.InferenceSession.create(modelPath, {
        // Restrict the number of threads to 1.
        intraOpNumThreads: 1,
        // Be more conservative with RAM usage.
        enableCpuMemArena: false,
    });
};
const cachedCLIPImageSession = makeCachedInferenceSession(
    "clip-image-vit-32-float32.onnx",
    351468764 /* 335.2 MB */,
);

/**
 * Compute CLIP embeddings for an image.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */
export const computeCLIPImageEmbedding = async (input: Float32Array) => {
    const session = await cachedCLIPImageSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 224, 224]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP image embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to type the result */
    return ensure(results.output).data as Float32Array;
};

const cachedCLIPTextSession = makeCachedInferenceSession(
    "clip-text-vit-32-uint8.onnx",
    64173509 /* 61.2 MB */,
);

let _tokenizer: Tokenizer | undefined;
const getTokenizer = () => {
    if (!_tokenizer) _tokenizer = new Tokenizer();
    return _tokenizer;
};

/**
 * Compute CLIP embeddings for a text snippet.
 *
 * The embeddings are computed using ONNX runtime, with CLIP as the model.
 */
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
    const sessionOrSkip = await Promise.race([
        cachedCLIPTextSession(),
        // Wait a tick to let the session promise resolve the first time this
        // code runs on each app start (if the model is already downloaded).
        wait(0).then(() => 1),
    ]);

    // Don't wait for the download to complete.
    if (typeof sessionOrSkip == "number") {
        log.info(
            "Ignoring CLIP text embedding request because model download is pending",
        );
        return undefined;
    }

    const session = sessionOrSkip;
    const t = Date.now();
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };

    const results = await session.run(feeds);
    log.debug(() => `ONNX/CLIP text embedding took ${Date.now() - t} ms`);
    return ensure(results.output).data as Float32Array;
};

const cachedFaceDetectionSession = makeCachedInferenceSession(
    "yolov5s_face_640_640_dynamic.onnx",
    30762872 /* 29.3 MB */,
);

/**
 * Face detection with the YOLO model and ONNX runtime.
 */
export const detectFaces = async (input: Float32Array) => {
    const session = await cachedFaceDetectionSession();
    const t = Date.now();
    const feeds = {
        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
    };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
    return ensure(results.output).data;
};

const cachedFaceEmbeddingSession = makeCachedInferenceSession(
    "mobilefacenet_opset15.onnx",
    5286998 /* 5 MB */,
);

/**
 * Face embedding with the MobileFaceNet model and ONNX runtime.
 */
export const computeFaceEmbeddings = async (input: Float32Array) => {
    // Dimension of each face (alias)
    const mobileFaceNetFaceSize = 112;
    // Smaller alias
    const z = mobileFaceNetFaceSize;
    // Number of faces in the batch
    const n = Math.round(input.length / (z * z * 3));
    const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

    const session = await cachedFaceEmbeddingSession();
    const t = Date.now();
    const feeds = { img_inputs: inputTensor };
    const results = await session.run(feeds);
    log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
    /* Need these model specific casts to extract and type the result */
    return (results.embeddings as unknown as Record<string, unknown>)
        .cpuData as Float32Array;
};
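To show how the new module's exports are meant to be used, here is a minimal, hypothetical caller. The preprocessing that produces a normalized 1 × 3 × 224 × 224 CLIP input is assumed, and the caller function names below are illustrative, not part of this commit:

import {
    computeCLIPImageEmbedding,
    computeCLIPTextEmbeddingIfAvailable,
} from "./services/ml-utility";

// Hypothetical caller: assumes `clipInput` is already a 1x3x224x224 tensor
// flattened into a Float32Array.
export const indexImage = async (clipInput: Float32Array) =>
    await computeCLIPImageEmbedding(clipInput);

// Hypothetical caller: returns undefined while the text model is still
// downloading, so search can degrade gracefully instead of blocking.
export const embedSearchQuery = async (query: string) =>
    await computeCLIPTextEmbeddingIfAvailable(query);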
desktop/src/main/services/ml.ts
@@ -1,20 +1,7 @@
 /**
- * @file ML related functionality, generic layer.
- *
- * @see also `ml-clip.ts`, `ml-face.ts`.
- *
- * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
- * for various tasks are not shipped with the app but are downloaded on demand.
+ * @file ML related functionality. This code runs in the main process.
  */
 
-import { app, net } from "electron/main";
-import { existsSync } from "fs";
-import fs from "node:fs/promises";
-import path from "node:path";
-import * as ort from "onnxruntime-node";
-import log from "../log";
-import { writeStream } from "../stream";
-
 /**
  * Create a new ML session.
  *
@@ -53,15 +40,18 @@ import { writeStream } from "../stream";
  * to communicate. A pair of message ports is called a message channel. The cool
  * thing about these is that we can pass these ports themselves over IPC.
  *
+ * > One caveat here is that the message ports can only be passed using the
+ * > `postMessage` APIs, not the usual send/invoke APIs.
+ *
  * So we
  *
- * 1. Spawn a utility process.
- * 2. In the utility process create a message channel.
- * 3. Keep one port of the pair with us, and send the other over IPC to the
- *    _web worker_ that is coordinating the ML indexing on the web layer.
+ * 1. In the main process create a message channel.
+ * 2. Spawn a utility process, and send one port of the pair to it.
+ * 3. Send the other port of the pair to the renderer.
  *
- * Thereafter, the utility process and web worker can directly talk to each
- * other!
+ * The renderer will forward that port to the web worker that is coordinating
+ * the ML indexing on the web layer. Thereafter, the utility process and web
+ * worker can directly talk to each other!
  *
  * Node.js utility process <-> Renderer web worker
  *
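The steps above describe the intended wiring; createMLSession itself remains a stub in this commit (see the next hunk). A minimal sketch of how such wiring could look with Electron's MessageChannelMain and utilityProcess, where the "ml-utility.js" entry path and the "ml" channel name are illustrative assumptions:

import { MessageChannelMain, utilityProcess } from "electron/main";
import type { WebContents } from "electron/main";

// Sketch only; not the implementation in this commit.
export const createMLSessionSketch = (webContents: WebContents) => {
    // 1. In the main process, create a message channel.
    const { port1, port2 } = new MessageChannelMain();
    // 2. Spawn the utility process and send one port of the pair to it.
    const ml = utilityProcess.fork("ml-utility.js" /* assumed entry point */);
    ml.postMessage({ type: "ml" }, [port1]);
    // 3. Send the other port to the renderer, which forwards it to the web
    //    worker coordinating ML indexing on the web layer.
    webContents.postMessage("ml", null, [port2]);
};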
@@ -70,109 +60,3 @@ export const createMLSession = () => {
     // }: Promise<MessagePort> => {
     throw new Error("Not implemented");
 };
-
-/**
- * Return a function that can be used to trigger a download of the specified
- * model, and the creation of an ONNX inference session initialized using it.
- *
- * Multiple parallel calls to the returned function are fine: it ensures that
- * the model will be downloaded and the session created using it only once.
- * All pending calls to it meanwhile will just await on the same promise.
- *
- * And once the promise is resolved, the created ONNX inference session will be
- * cached, so subsequent calls to the returned function will just reuse the same
- * session.
- *
- * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
- * actively trigger a download until the returned function is called.
- *
- * @param modelName The name of the model to download.
- *
- * @param modelByteSize The size in bytes that we expect the model to have. If
- * the size of the downloaded model does not match the expected size, then we
- * will redownload it.
- *
- * @returns A function; calling that function returns a promise to an ONNX
- * session.
- */
-export const makeCachedInferenceSession = (
-    modelName: string,
-    modelByteSize: number,
-) => {
-    let session: Promise<ort.InferenceSession> | undefined;
-
-    const download = () =>
-        modelPathDownloadingIfNeeded(modelName, modelByteSize);
-
-    const createSession = (modelPath: string) =>
-        createInferenceSession(modelPath);
-
-    const cachedInferenceSession = () => {
-        if (!session) session = download().then(createSession);
-        return session;
-    };
-
-    return cachedInferenceSession;
-};
-
-/**
- * Download the model named {@link modelName} if we don't already have it.
- *
- * Also verify that the size of the model we get matches
- * {@link expectedByteSize} (if not, redownload it).
- *
- * @returns the path to the model on the local machine.
- */
-const modelPathDownloadingIfNeeded = async (
-    modelName: string,
-    expectedByteSize: number,
-) => {
-    const modelPath = modelSavePath(modelName);
-
-    if (!existsSync(modelPath)) {
-        log.info(`Model ${modelName} not found, downloading`);
-        await downloadModel(modelPath, modelName);
-    } else {
-        const size = (await fs.stat(modelPath)).size;
-        if (size !== expectedByteSize) {
-            log.error(
-                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
-            );
-            await downloadModel(modelPath, modelName);
-        }
-    }
-
-    return modelPath;
-};
-
-/** Return the path where the given {@link modelName} is meant to be saved */
-const modelSavePath = (modelName: string) =>
-    path.join(app.getPath("userData"), "models", modelName);
-
-const downloadModel = async (saveLocation: string, name: string) => {
-    // `mkdir -p` the directory where we want to save the model.
-    const saveDir = path.dirname(saveLocation);
-    await fs.mkdir(saveDir, { recursive: true });
-    // Download.
-    log.info(`Downloading ML model ${name}`);
-    const url = `https://models.ente.io/${name}`;
-    const res = await net.fetch(url);
-    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
-    const body = res.body;
-    if (!body) throw new Error(`Received a null response for ${url}`);
-    // Save.
-    await writeStream(saveLocation, body);
-    log.info(`Downloaded ML model ${name}`);
-};
-
-/**
- * Create an ONNX {@link InferenceSession} with some defaults.
- */
-const createInferenceSession = async (modelPath: string) => {
-    return await ort.InferenceSession.create(modelPath, {
-        // Restrict the number of threads to 1.
-        intraOpNumThreads: 1,
-        // Be more conservative with RAM usage.
-        enableCpuMemArena: false,
-    });
-};