[desktop] Enable ONNX Face ML (#1413)

Indexing + PUT embeddings work now, but most other things are still WIP.
Manav Rathi authored 2024-04-11 16:40:36 +05:30, committed by GitHub
19 changed files with 340 additions and 301 deletions


@@ -1,8 +1,8 @@
# Dependencies
* [Electron](#electron)
* [Dev dependencies](#dev)
* [Functionality](#functionality)
- [Electron](#electron)
- [Dev dependencies](#dev)
- [Functionality](#functionality)
## Electron
@@ -114,8 +114,8 @@ available on the host machine, and is not bundled with our app.
AI/ML runtime. It powers both natural language searches (using CLIP) and face
detection (using YOLO).
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
JPEG data into raw RGB bytes before passing it to ONNX.
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
data into raw RGB bytes before passing it to ONNX.
html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.
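To make the jpeg-js step above concrete, here is a rough sketch (not the app's actual preprocessing, which also resizes and normalizes per model) of decoding a JPEG and repacking its RGBA pixels into the flat RGB floats an ONNX image model consumes; `jpegToRGBFloat32` is a hypothetical helper name:

```ts
import jpeg from "jpeg-js";

// Hypothetical illustration: decode JPEG bytes with jpeg-js and repack the
// RGBA output into a flat RGB Float32Array in [0, 1]. The real pipeline also
// resizes and normalizes the pixels to whatever the specific model expects.
const jpegToRGBFloat32 = (jpegData: Uint8Array): Float32Array => {
    const { width, height, data } = jpeg.decode(jpegData, { useTArray: true });
    const rgb = new Float32Array(width * height * 3);
    for (let i = 0, j = 0; i < data.length; i += 4, j += 3) {
        rgb[j] = data[i] / 255; // R
        rgb[j + 1] = data[i + 1] / 255; // G
        rgb[j + 2] = data[i + 2] / 255; // B (the alpha byte is dropped)
    }
    return rgb;
};
```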


@@ -36,13 +36,14 @@ import {
updateAndRestart,
updateOnNextRestart,
} from "./services/app-update";
import { clipImageEmbedding, clipTextEmbedding } from "./services/clip";
import { runFFmpegCmd } from "./services/ffmpeg";
import { getDirFiles } from "./services/fs";
import {
convertToJPEG,
generateImageThumbnail,
} from "./services/imageProcessor";
import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
import { detectFaces, faceEmbedding } from "./services/ml-face";
import {
clearStores,
encryptionKey,
@@ -146,6 +147,14 @@ export const attachIPCHandlers = () => {
clipTextEmbedding(text),
);
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);
ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>
faceEmbedding(input),
);
// - File selection
ipcMain.handle("selectDirectory", () => selectDirectory());


@@ -1,19 +1,13 @@
/**
* @file Compute CLIP embeddings
* @file Compute CLIP embeddings for images and text.
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details. This
* file implements the Node.js implementation of the actual embedding
* computation. By doing it in the Node.js layer, we can use the binary ONNX
* runtimes which are 10-20x faster than the WASM based web ones.
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*
* The embeddings are computed using ONNX runtime. The model itself is not
* shipped with the app but is downloaded on demand.
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import { CustomErrors } from "../../types/ipc";
@@ -21,6 +15,12 @@ import { writeStream } from "../fs";
import log from "../log";
import { generateTempFilePath } from "../temp";
import { deleteTempFile } from "./ffmpeg";
import {
createInferenceSession,
downloadModel,
modelPathDownloadingIfNeeded,
modelSavePath,
} from "./ml";
const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB
@@ -28,55 +28,20 @@ const textModelByteSize = 64173509; // 61.2 MB
const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB
/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading CLIP model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded CLIP model ${name}`);
};
let activeImageModelDownload: Promise<void> | undefined;
let activeImageModelDownload: Promise<string> | undefined;
const imageModelPathDownloadingIfNeeded = async () => {
try {
const modelPath = modelSavePath(imageModelName);
if (activeImageModelDownload) {
log.info("Waiting for CLIP image model download to finish");
await activeImageModelDownload;
} else {
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== imageModelByteSize) {
log.error(
`CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
);
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
}
}
activeImageModelDownload = modelPathDownloadingIfNeeded(
imageModelName,
imageModelByteSize,
);
return await activeImageModelDownload;
}
return modelPath;
} finally {
activeImageModelDownload = undefined;
}
@@ -84,6 +49,8 @@ const imageModelPathDownloadingIfNeeded = async () => {
let textModelDownloadInProgress = false;
/* TODO(MR): use the generic method. Then we can remove the exports of the
internal helper functions that we use here. */
const textModelPathDownloadingIfNeeded = async () => {
if (textModelDownloadInProgress)
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
@@ -123,13 +90,6 @@ const textModelPathDownloadingIfNeeded = async () => {
return modelPath;
};
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
};
let imageSessionPromise: Promise<any> | undefined;
const onnxImageSession = async () => {
@@ -174,7 +134,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const imageEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(imageEmbedding);
@@ -281,7 +241,7 @@ export const clipTextEmbedding = async (text: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data;
return normalizeEmbedding(textEmbedding);
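Since both the image and text embeddings returned by this file are L2-normalized (via normalizeEmbedding), comparing them for search reduces to a dot product, i.e. cosine similarity. A minimal sketch of that comparison; the ranking helper below is illustrative and not part of this diff:

```ts
// For unit-length vectors the dot product equals the cosine similarity.
const cosineSimilarity = (a: Float32Array, b: Float32Array): number => {
    let dot = 0;
    for (let i = 0; i < a.length; i++) dot += a[i] * b[i];
    return dot;
};

// Illustrative ranking of image embeddings against a text query embedding,
// highest similarity first. Both sides are assumed to be normalized already.
const rankBySimilarity = (
    textEmbedding: Float32Array,
    imageEmbeddings: Map<string, Float32Array>,
): [string, number][] =>
    [...imageEmbeddings.entries()]
        .map(([id, e]): [string, number] => [id, cosineSimilarity(textEmbedding, e)])
        .sort((a, b) => b[1] - a[1]);
```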


@@ -0,0 +1,108 @@
/**
* @file Various face recognition related tasks.
*
* - Face detection with the YOLO model.
* - Face embedding with the MobileFaceNet model.
*
* The runtime used is ONNX.
*/
import * as ort from "onnxruntime-node";
import log from "../log";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
const faceDetectionModelByteSize = 30762872; // 29.3 MB
const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
const faceEmbeddingModelByteSize = 5286998; // 5 MB
let activeFaceDetectionModelDownload: Promise<string> | undefined;
const faceDetectionModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceDetectionModelDownload) {
log.info("Waiting for face detection model download to finish");
return await activeFaceDetectionModelDownload;
} else {
activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
faceDetectionModelName,
faceDetectionModelByteSize,
);
return await activeFaceDetectionModelDownload;
}
} finally {
activeFaceDetectionModelDownload = undefined;
}
};
let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
const faceDetectionSession = async () => {
if (!_faceDetectionSession) {
_faceDetectionSession =
faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceDetectionSession;
};
let activeFaceEmbeddingModelDownload: Promise<string> | undefined;
const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceEmbeddingModelDownload) {
log.info("Waiting for face embedding model download to finish");
return await activeFaceEmbeddingModelDownload;
} else {
activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
faceEmbeddingModelName,
faceEmbeddingModelByteSize,
);
return await activeFaceEmbeddingModelDownload;
}
} finally {
activeFaceEmbeddingModelDownload = undefined;
}
};
let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;
const faceEmbeddingSession = async () => {
if (!_faceEmbeddingSession) {
_faceEmbeddingSession =
faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceEmbeddingSession;
};
export const detectFaces = async (input: Float32Array) => {
const session = await faceDetectionSession();
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`);
return results["output"].data;
};
export const faceEmbedding = async (input: Float32Array) => {
// Dimension of each face (alias)
const mobileFaceNetFaceSize = 112;
// Smaller alias
const z = mobileFaceNetFaceSize;
// Size of each face's data in the batch
const n = Math.round(input.length / (z * z * 3));
const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
const session = await faceEmbeddingSession();
const t = Date.now();
const feeds = { img_inputs: inputTensor };
const results = await session.run(feeds);
log.debug(() => `onnx/mobilefacenet face embedding took ${Date.now() - t} ms`);
// TODO: What's with this type? It works in practice, but double check.
return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
};
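Two notes on the code above. First, sizing: faceEmbedding expects n faces packed as n × 112 × 112 × 3 floats, so a two-face batch has input.length = 75264 and n resolves to 2. Second, regarding the cpuData TODO: in onnxruntime-node the values in the map returned by session.run() are ort.Tensor instances, so a typed alternative (assuming the embeddings output really is a float32 tensor) could read the tensor's data property instead. A sketch, untested against the actual model:

```ts
import * as ort from "onnxruntime-node";

// Possible typed read of the MobileFaceNet output, assuming `embeddings` is a
// float32 tensor in the result map. The existing `cpuData` access is kept in
// the diff above; this is only a suggestion.
const readEmbeddings = (
    results: ort.InferenceSession.OnnxValueMapType,
): Float32Array => (results["embeddings"] as ort.Tensor).data as Float32Array;
```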


@@ -0,0 +1,79 @@
/**
* @file AI/ML related functionality.
*
* @see also `ml-clip.ts`, `ml-face.ts`.
*
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
* for various tasks are not shipped with the app but are downloaded on demand.
*
* The primary reason for doing these tasks in the Node.js layer is so that we
* can use the binary ONNX runtime, which is 10-20x faster than the WASM-based
* web one.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import { writeStream } from "../fs";
import log from "../log";
/**
* Download the model named {@link modelName} if we don't already have it.
*
* Also verify that the size of the model we get matches
* {@link expectedByteSize} (if not, redownload it).
*
* @returns the path to the model on the local machine.
*/
export const modelPathDownloadingIfNeeded = async (
modelName: string,
expectedByteSize: number,
) => {
const modelPath = modelSavePath(modelName);
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
await downloadModel(modelPath, modelName);
} else {
const size = (await fs.stat(modelPath)).size;
if (size !== expectedByteSize) {
log.error(
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
);
await downloadModel(modelPath, modelName);
}
}
return modelPath;
};
/** Return the path where the given {@link modelName} is meant to be saved */
export const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
export const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading ML model ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded ML model ${name}`);
};
/**
* Create an ONNX {@link InferenceSession} with some defaults.
*/
export const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1
intraOpNumThreads: 1,
// Be more conservative with RAM usage
enableCpuMemArena: false,
});
};
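As a usage note, the two new files above (ml-clip.ts and ml-face.ts) compose these exports in the same way: keep a single in-flight download promise so concurrent callers share one download, then lazily create and cache the inference session. A hedged sketch of wiring up a hypothetical extra model (the model name and byte size below are placeholders, not a real model served from models.ente.io):

```ts
import * as ort from "onnxruntime-node";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";

// Placeholder model metadata, for illustration only.
const exampleModelName = "example-model.onnx";
const exampleModelByteSize = 123456;

let activeExampleModelDownload: Promise<string> | undefined;

// Deduplicate concurrent downloads by sharing one in-flight promise.
const exampleModelPathDownloadingIfNeeded = async () => {
    try {
        if (!activeExampleModelDownload) {
            activeExampleModelDownload = modelPathDownloadingIfNeeded(
                exampleModelName,
                exampleModelByteSize,
            );
        }
        return await activeExampleModelDownload;
    } finally {
        activeExampleModelDownload = undefined;
    }
};

// Lazily create, then cache, the ONNX inference session for this model.
let _exampleSession: Promise<ort.InferenceSession> | undefined;
const exampleSession = async () => {
    if (!_exampleSession)
        _exampleSession = exampleModelPathDownloadingIfNeeded().then(
            (modelPath) => createInferenceSession(modelPath),
        );
    return _exampleSession;
};
```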


@@ -143,6 +143,12 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
ipcRenderer.invoke("clipTextEmbedding", text);
const detectFaces = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("detectFaces", input);
const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("faceEmbedding", input);
// - File selection
// TODO: Deprecated - use dialogs on the renderer process itself
@@ -322,6 +328,8 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
clipImageEmbedding,
clipTextEmbedding,
detectFaces,
faceEmbedding,
// - File selection
selectDirectory,


@@ -53,15 +53,4 @@ export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export const TEXT_DETECTION_TIMEOUT_MS = [10000, 30000, 60000, 120000, 240000];
export const MOBILEFACENET_FACE_SIZE = 112;
export const MOBILEFACENET_EMBEDDING_SIZE = 192;
export const BATCHES_BEFORE_SYNCING_INDEX = 5;


@@ -1,4 +1,6 @@
import { inWorker } from "@/next/env";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error";
import HTTPService from "@ente/shared/network/HTTPService";
@@ -262,7 +264,9 @@ export const putEmbedding = async (
putEmbeddingReq: PutEmbeddingRequest,
): Promise<EncryptedEmbedding> => {
try {
const token = getToken();
const token = inWorker()
? await workerBridge.getAuthToken()
: getToken();
if (!token) {
log.info("putEmbedding failed: token not found");
throw Error(CustomError.TOKEN_MISSING);


@@ -55,7 +55,7 @@ class FaceService {
await syncContext.faceDetectionService.detectFaces(imageBitmap);
console.timeEnd(timerId);
console.log("faceDetections: ", faceDetections?.length);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: reenable faces filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
return {
@@ -150,7 +150,7 @@ class FaceService {
imageBitmap.close();
log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
// log.info('4 TF Memory stats: ',JSON.stringify(tf.memory()));
return faceImages;
}
@@ -187,7 +187,6 @@ class FaceService {
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
// log.info('5 TF Memory stats: ',JSON.stringify(tf.memory()));
}
async syncFileFaceMakeRelativeDetections(
@@ -226,11 +225,21 @@ class FaceService {
face.detection,
syncContext.config.faceCrop,
);
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
try {
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
} catch (e) {
// TODO(MR): Temporarily ignoring errors about failing cache puts
// when using a custom scheme in Electron. Needs an alternative
// approach, perhaps OPFS.
console.error(
"Ignoring error when caching face crop, the face crop will not be available",
e,
);
}
const blob = await imageBitmapToBlob(faceCrop.image);
faceCrop.image.close();
return blob;


@@ -1,10 +1,10 @@
import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
import {
BlurDetectionMethod,
BlurDetectionService,
Versioned,
} from "types/machineLearning";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
import { mobileFaceNetFaceSize } from "./mobileFaceNetEmbeddingService";
class LaplacianBlurDetectionService implements BlurDetectionService {
public method: Versioned<BlurDetectionMethod>;
@@ -19,7 +19,7 @@ class LaplacianBlurDetectionService implements BlurDetectionService {
public detectBlur(alignedFaces: Float32Array): number[] {
const numFaces = Math.round(
alignedFaces.length /
(MOBILEFACENET_FACE_SIZE * MOBILEFACENET_FACE_SIZE * 3),
(mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3),
);
const blurValues: number[] = [];
for (let i = 0; i < numFaces; i++) {


@@ -203,9 +203,6 @@ export class LocalMLSyncContext implements MLSyncContext {
}
public async dispose() {
// await this.faceDetectionService.dispose();
// await this.faceEmbeddingService.dispose();
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();


@@ -186,8 +186,7 @@ class MLWorkManager {
return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
});
// @ts-expect-error "TODO: Fix ML related type errors"
if ("message" in result) {
if (result instanceof Error) {
// TODO: redirect/refresh to gallery in case of session_expired
// may not be required as uploader should anyways take care of this
console.error("Error while syncing local file: ", result);


@@ -1,7 +1,4 @@
import {
MOBILEFACENET_EMBEDDING_SIZE,
MOBILEFACENET_FACE_SIZE,
} from "constants/mlConfig";
import { workerBridge } from "@/next/worker/worker-bridge";
import {
FaceEmbedding,
FaceEmbeddingMethod,
@@ -9,17 +6,9 @@ import {
Versioned,
} from "types/machineLearning";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
export const mobileFaceNetFaceSize = 112;
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceEmbeddingMethod>;
public faceSize: number;
@@ -28,80 +17,25 @@ class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
value: "MobileFaceNet",
version: 2,
};
this.faceSize = MOBILEFACENET_FACE_SIZE;
// TODO: set timeout
}
private async initOnnx() {
console.log("start ort mobilefacenet");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/mobilefacenet/mobilefacenet_opset15.onnx",
);
const faceBatchSize = 1;
const data = new Float32Array(
faceBatchSize * 3 * this.faceSize * this.faceSize,
);
const inputTensor = new ort.Tensor("float32", data, [
faceBatchSize,
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end mobilefacenet");
}
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
this.faceSize = mobileFaceNetFaceSize;
}
public async getFaceEmbeddings(
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> {
const inputTensor = new ort.Tensor("float32", faceData, [
Math.round(faceData.length / (this.faceSize * this.faceSize * 3)),
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["img_inputs"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
// TODO(MR): onnx-yolo
// const runout: ort.InferenceSession.OnnxValueMapType =
const runout: any = await inferenceSession.run(feeds);
// const test = runout.embeddings;
// const test2 = test.cpuData;
const outputData = runout.embeddings["cpuData"] as Float32Array;
const outputData = await workerBridge.faceEmbedding(faceData);
const embeddingSize = 192;
const embeddings = new Array<FaceEmbedding>(
outputData.length / MOBILEFACENET_EMBEDDING_SIZE,
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(
i * MOBILEFACENET_EMBEDDING_SIZE,
(i + 1) * MOBILEFACENET_EMBEDDING_SIZE,
),
outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
);
}
return embeddings;
}
public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}
export default new MobileFaceNetEmbeddingService();
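To make the slicing in getFaceEmbeddings concrete: the worker bridge returns a single flat Float32Array containing 192 floats per face (the embeddingSize above), so a two-face batch yields 384 values that get split into two 192-dimensional embeddings. A standalone sketch of that reshaping:

```ts
// Split a flat array of concatenated face embeddings into per-face vectors.
// E.g. with embeddingSize = 192, an outputData of length 384 becomes two
// embeddings: outputData[0..191] and outputData[192..383].
const splitEmbeddings = (
    outputData: Float32Array,
    embeddingSize = 192,
): Float32Array[] => {
    const embeddings: Float32Array[] = [];
    for (let i = 0; i < outputData.length; i += embeddingSize)
        embeddings.push(outputData.slice(i, i + embeddingSize));
    return embeddings;
};
```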


@@ -1,4 +1,5 @@
import { MAX_FACE_DISTANCE_PERCENT } from "constants/mlConfig";
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import {
Matrix,
applyToPoint,
@@ -21,17 +22,7 @@ import {
import { newBox } from "utils/machineLearning";
import { Box, Point } from "../../../thirdparty/face-api/classes";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class YoloFaceDetectionService implements FaceDetectionService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceDetectionMethod>;
public constructor() {
@@ -41,27 +32,38 @@ class YoloFaceDetectionService implements FaceDetectionService {
};
}
private async initOnnx() {
console.log("start ort");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const data = new Float32Array(1 * 3 * 640 * 640);
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end");
}
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
private preprocessImageBitmapToFloat32ChannelsFirst(
@@ -156,43 +158,6 @@ class YoloFaceDetectionService implements FaceDetectionService {
};
}
/**
* @deprecated The method should not be used
*/
private imageBitmapToTensorData(imageBitmap) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const data = new Float32Array(
1 * 3 * imageBitmap.width * imageBitmap.height,
);
// Populate the Float32Array with normalized pixel values
for (let i = 0; i < pixelData.length; i += 4) {
// Normalize pixel values to the range [0, 1]
data[i / 4] = pixelData[i] / 255.0; // Red channel
data[i / 4 + imageBitmap.width * imageBitmap.height] =
pixelData[i + 1] / 255.0; // Green channel
data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] =
pixelData[i + 2] / 255.0; // Blue channel
}
return {
data: data,
shape: [1, 3, imageBitmap.width, imageBitmap.height],
};
}
// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
private getFacesFromYoloOutput(
rowOutput: Float32Array,
@@ -270,64 +235,10 @@ class YoloFaceDetectionService implements FaceDetectionService {
probability: faceDetection.probability,
};
}
private async estimateOnnx(imageBitmap: ImageBitmap) {
const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["input"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
const runout = await inferenceSession.run(feeds);
const outputData = runout.output.data;
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
// measure time taken
const facesFromOnnx = await this.estimateOnnx(imageBitmap);
return facesFromOnnx;
}
public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}
export default new YoloFaceDetectionService();
import { euclidean } from "hdbscan";
/**
* Removes duplicate face detections from an array of detections.
*


@@ -45,16 +45,13 @@ export declare type Landmark = Point;
export declare type ImageType = "Original" | "Preview";
export declare type FaceDetectionMethod = "FaceApiSSD" | "YoloFace";
export declare type FaceDetectionMethod = "YoloFace";
export declare type FaceCropMethod = "ArcFace";
export declare type FaceAlignmentMethod =
| "ArcFace"
| "FaceApiDlib"
| "RotatedFaceApiDlib";
export declare type FaceAlignmentMethod = "ArcFace";
export declare type FaceEmbeddingMethod = "MobileFaceNet" | "FaceApiDlib";
export declare type FaceEmbeddingMethod = "MobileFaceNet";
export declare type BlurDetectionMethod = "Laplacian";
@@ -261,13 +258,12 @@ export declare type MLIndex = "files" | "people";
export interface FaceDetectionService {
method: Versioned<FaceDetectionMethod>;
// init(): Promise<void>;
detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
getRelativeDetection(
faceDetection: FaceDetection,
imageDimensions: Dimensions,
): FaceDetection;
dispose(): Promise<void>;
}
export interface FaceCropService {
@@ -288,9 +284,8 @@ export interface FaceAlignmentService {
export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;
// init(): Promise<void>;
getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
dispose(): Promise<void>;
}
export interface BlurDetectionService {


@@ -50,7 +50,7 @@ export class SimpleJob<R extends JobResult> {
try {
const jobResult = await this.runCallback();
if (jobResult.shouldBackoff) {
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(
this.config.maxItervalSec,
this.intervalSec * this.config.backoffMultiplier,


@@ -124,6 +124,10 @@ class MLIDbStorage {
.add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
}
if (oldVersion < 4) {
// TODO(MR): This loses the user's settings.
db.deleteObjectStore("configs");
db.createObjectStore("configs");
db.deleteObjectStore("things");
}


@@ -196,7 +196,7 @@ export interface Electron {
// - ML
/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
*
@@ -207,7 +207,7 @@ export interface Electron {
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given text.
*
* See: [Note: CLIP based magic search]
*
@@ -217,6 +217,22 @@ export interface Electron {
*/
clipTextEmbedding: (text: string) => Promise<Float32Array>;
/**
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a MobileFaceNet embedding for the given face data.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
faceEmbedding: (input: Float32Array) => Promise<Float32Array>;
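A hedged sketch of how renderer code is expected to chain these two calls; `preprocessForYolo` and `cropAndPreprocessFaces` are hypothetical stand-ins for the preprocessing done by the ML services elsewhere in this diff (a 1×3×640×640 float tensor for YOLO, and n×112×112×3 floats for MobileFaceNet):

```ts
// Assumes `electron` is the bridge object exposed by the preload script (the
// web app reaches it via ensureElectron()). The two preprocessing helpers are
// hypothetical placeholders, not functions added by this commit.
declare const electron: {
    detectFaces: (input: Float32Array) => Promise<Float32Array>;
    faceEmbedding: (input: Float32Array) => Promise<Float32Array>;
};
declare const preprocessForYolo: (image: ImageBitmap) => Float32Array; // 1 * 3 * 640 * 640
declare const cropAndPreprocessFaces: (
    image: ImageBitmap,
    detections: Float32Array,
) => Float32Array; // n * 112 * 112 * 3

const indexFaces = async (image: ImageBitmap): Promise<Float32Array> => {
    const detections = await electron.detectFaces(preprocessForYolo(image));
    const faceData = cropAndPreprocessFaces(image, detections);
    return electron.faceEmbedding(faceData);
};
```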
// - File selection
// TODO: Deprecated - use dialogs on the renderer process itself


@@ -35,6 +35,19 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
}
}
// TODO(MR): Temporary method to forward auth tokens to workers
const getAuthToken = () => {
// LS_KEYS.USER
const userJSONString = localStorage.getItem("user");
if (!userJSONString) return undefined;
const json: unknown = JSON.parse(userJSONString);
if (!json || typeof json != "object" || !("token" in json))
return undefined;
const token = json.token;
if (typeof token != "string") return undefined;
return token;
};
/**
* A minimal set of utility functions that we expose to all workers that we
* create.
@@ -44,8 +57,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
*/
const workerBridge = {
logToDisk,
getAuthToken,
convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
ensureElectron().convertToJPEG(inputFileData, filename),
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
faceEmbedding: (input: Float32Array) =>
ensureElectron().faceEmbedding(input),
};
export type WorkerBridge = typeof workerBridge;