From a1d6ef43b4565733010e76b413090cdf61681729 Mon Sep 17 00:00:00 2001 From: Manav Rathi Date: Thu, 11 Apr 2024 14:37:44 +0530 Subject: [PATCH] Roundtrip --- desktop/src/main/ipc.ts | 4 +- desktop/src/main/services/ml-face.ts | 38 +++++++++++-------- desktop/src/preload.ts | 4 +- .../yoloFaceDetectionService.ts | 10 +---- web/packages/next/types/ipc.ts | 9 ++--- 5 files changed, 33 insertions(+), 32 deletions(-) diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index b6e8848183..180e68cdcf 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -147,8 +147,8 @@ export const attachIPCHandlers = () => { clipTextEmbedding(text), ); - ipcMain.handle("detectFaces", (_, imageData: Uint8Array) => - detectFaces(imageData), + ipcMain.handle("detectFaces", (_, input: Float32Array) => + detectFaces(input), ); ipcMain.handle("faceEmbedding", (_, input: Float32Array) => diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts index 066f5406bf..78eb82bd15 100644 --- a/desktop/src/main/services/ml-face.ts +++ b/desktop/src/main/services/ml-face.ts @@ -78,8 +78,29 @@ const faceEmbeddingSession = async () => { return _faceEmbeddingSession; }; -export const detectFaces = async (inputImage: Uint8Array) => { - throw new Error("test"); +export const detectFaces = async (input: Float32Array) => { + // console.log("start ort"); + // this.onnxInferenceSession = await ort.InferenceSession.create( + // "/models/yoloface/yolov5s_face_640_640_dynamic.onnx", + // ); + // const data = new Float32Array(1 * 3 * 640 * 640); + // const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]); + // // TODO(MR): onnx-yolo + // // const feeds: Record = {}; + // const feeds: Record = {}; + // const name = this.onnxInferenceSession.inputNames[0]; + // feeds[name] = inputTensor; + // await this.onnxInferenceSession.run(feeds); + // console.log("start end"); + + const session = await faceDetectionSession(); + const t = Date.now(); + const feeds = { + input: new ort.Tensor("float32", input, [1, 3, 640, 640]), + }; + const results = await session.run(feeds); + log.debug(() => `onnx/yolo inference took ${Date.now() - t} ms`); + return results["output"].data; }; export const faceEmbedding = async (input: Float32Array) => { @@ -89,19 +110,6 @@ export const faceEmbedding = async (input: Float32Array) => { /* private async initOnnx() { - console.log("start ort"); - this.onnxInferenceSession = await ort.InferenceSession.create( - "/models/yoloface/yolov5s_face_640_640_dynamic.onnx", - ); - const data = new Float32Array(1 * 3 * 640 * 640); - const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]); - // TODO(MR): onnx-yolo - // const feeds: Record = {}; - const feeds: Record = {}; - const name = this.onnxInferenceSession.inputNames[0]; - feeds[name] = inputTensor; - await this.onnxInferenceSession.run(feeds); - console.log("start end"); } private async getOnnxInferenceSession() { diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index bea5c9e18f..2db39e2290 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -143,8 +143,8 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise => const clipTextEmbedding = (text: string): Promise => ipcRenderer.invoke("clipTextEmbedding", text); -const detectFaces = (imageData: Uint8Array): Promise => - ipcRenderer.invoke("detectFaces", imageData); +const detectFaces = (input: Float32Array): Promise => + ipcRenderer.invoke("detectFaces", input); const faceEmbedding = (input: Float32Array): Promise => ipcRenderer.invoke("faceEmbedding", input); diff --git a/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts b/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts index 02e5bb02b2..9efd31cbb7 100644 --- a/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts +++ b/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts @@ -1,3 +1,4 @@ +import { ensureElectron } from "@/next/electron"; import { MAX_FACE_DISTANCE_PERCENT } from "constants/mlConfig"; import { euclidean } from "hdbscan"; import { @@ -44,14 +45,7 @@ class YoloFaceDetectionService implements FaceDetectionService { ); const data = preprocessResult.data; const resized = preprocessResult.newSize; - const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]); - // TODO(MR): onnx-yolo - // const feeds: Record = {}; - const feeds: Record = {}; - feeds["input"] = inputTensor; - const inferenceSession = await this.getOnnxInferenceSession(); - const runout = await inferenceSession.run(feeds); - const outputData = runout.output.data; + const outputData = await ensureElectron().detectFaces(data); const faces = this.getFacesFromYoloOutput( outputData as Float32Array, 0.7, diff --git a/web/packages/next/types/ipc.ts b/web/packages/next/types/ipc.ts index 83d9ee6bdd..5b0979eaa2 100644 --- a/web/packages/next/types/ipc.ts +++ b/web/packages/next/types/ipc.ts @@ -221,16 +221,15 @@ export interface Electron { * Detect faces in the given image using YOLO. * * Both the input and output are opaque binary data whose internal structure - * is model (YOLO) and our implementation specific. That said, specifically - * the {@link inputImage} a particular bitmap encoding of an image. + * is specific to our implementation and the model (YOLO) we use. */ - detectFaces: (inputImage: Uint8Array) => Promise; + detectFaces: (input: Float32Array) => Promise; /** - * Return a mobilefacenet embedding for the given face data. + * Return a MobileFaceNet embedding for the given face data. * * Both the input and output are opaque binary data whose internal structure - * is model (mobilefacenet) and our implementation specific. + * is specific to our implementation and the model (MobileFaceNet) we use. */ faceEmbedding: (input: Float32Array) => Promise;