Manav Rathi
2024-07-04 19:13:21 +05:30
parent 5698ee12a3
commit 38e23ca1b5
5 changed files with 38 additions and 38 deletions

View File

@@ -27,7 +27,7 @@ import {
 pixelRGBBilinear,
 warpAffineFloat32List,
 } from "./image";
-import type { MLWorkerDelegate } from "./worker-delegate";
+import type { MLWorkerElectron } from "./worker-electron";
 /**
  * The version of the face indexing pipeline implemented by the current client.
@@ -221,7 +221,7 @@ export interface Box {
 export const indexFaces = async (
 enteFile: EnteFile,
 file: File | undefined,
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 userAgent: string,
 ) => {
 const imageBitmap = await renderableImageBitmap(enteFile, file);
@@ -236,7 +236,7 @@ export const indexFaces = async (
 faceEmbedding: {
 version: faceIndexingVersion,
 client: userAgent,
-faces: await indexFacesInBitmap(fileID, imageBitmap, delegate),
+faces: await indexFacesInBitmap(fileID, imageBitmap, electron),
 },
 };
 // This step, saving face crops, is not part of the indexing pipeline;
@@ -257,12 +257,12 @@ export const indexFaces = async (
 const indexFacesInBitmap = async (
 fileID: number,
 imageBitmap: ImageBitmap,
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 ): Promise<Face[]> => {
 const { width, height } = imageBitmap;
 const imageDimensions = { width, height };
-const yoloFaceDetections = await detectFaces(imageBitmap, delegate);
+const yoloFaceDetections = await detectFaces(imageBitmap, electron);
 const partialResult = yoloFaceDetections.map(
 ({ box, landmarks, score }) => {
 const faceID = makeFaceID(fileID, box, imageDimensions);
@@ -283,7 +283,7 @@ const indexFacesInBitmap = async (
 alignments,
 );
-const embeddings = await computeEmbeddings(alignedFacesData, delegate);
+const embeddings = await computeEmbeddings(alignedFacesData, electron);
 const blurs = detectBlur(
 alignedFacesData,
 partialResult.map((f) => f.detection),
@@ -305,7 +305,7 @@ const indexFacesInBitmap = async (
 */
 const detectFaces = async (
 imageBitmap: ImageBitmap,
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 ): Promise<YOLOFaceDetection[]> => {
 const rect = ({ width, height }: Dimensions) => ({
 x: 0,
@@ -316,7 +316,7 @@ const detectFaces = async (
 const { yoloInput, yoloSize } =
 convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
-const yoloOutput = await delegate.detectFaces(yoloInput);
+const yoloOutput = await electron.detectFaces(yoloInput);
 const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput);
 const faceDetections = transformYOLOFaceDetections(
 faces,
@@ -874,9 +874,9 @@ const mobileFaceNetEmbeddingSize = 192;
 */
 const computeEmbeddings = async (
 faceData: Float32Array,
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 ): Promise<Float32Array[]> => {
-const outputData = await delegate.computeFaceEmbeddings(faceData);
+const outputData = await electron.computeFaceEmbeddings(faceData);
 const embeddingSize = mobileFaceNetEmbeddingSize;
 const embeddings = new Array<Float32Array>(
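
For orientation, a minimal sketch of how a flat embeddings buffer like the one being filled above could be split into per-face vectors of mobileFaceNetEmbeddingSize (192) floats each; the helper below is illustrative and is not the file's actual continuation:

    // Illustrative sketch only: split the flat result of
    // computeFaceEmbeddings into one Float32Array per face.
    const splitEmbeddings = (
        outputData: Float32Array,
        embeddingSize: number,
    ): Float32Array[] => {
        const embeddings = new Array<Float32Array>(
            Math.floor(outputData.length / embeddingSize),
        );
        for (let i = 0; i < embeddings.length; i++)
            // slice() on a Float32Array returns a new Float32Array copy.
            embeddings[i] = outputData.slice(
                i * embeddingSize,
                (i + 1) * embeddingSize,
            );
        return embeddings;
    };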

View File

@@ -43,7 +43,7 @@ const worker = async () => {
 const createComlinkWorker = async () => {
 const electron = ensureElectron();
-const delegate = {
+const mlWorkerElectron = {
 appVersion: electron.appVersion,
 detectFaces: electron.detectFaces,
 computeFaceEmbeddings: electron.computeFaceEmbeddings,
@@ -53,7 +53,7 @@ const createComlinkWorker = async () => {
 "ml",
 new Worker(new URL("worker.ts", import.meta.url)),
 );
-await cw.remote.then((w) => w.init(proxy(delegate)));
+await cw.remote.then((w) => w.init(proxy(mlWorkerElectron)));
 return cw;
 };
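
For context, the comlink pattern in play in this hunk, reduced to a self-contained sketch. The ComlinkWorker wrapper and ensureElectron above are the app's own; the sketch uses comlink directly, assumes the worker side exposes an MLWorker instance, and its startMLWorker helper and module paths are illustrative:

    // Illustrative sketch, assuming worker.ts runs: expose(new MLWorker()).
    import { proxy, wrap } from "comlink";

    import type { MLWorker } from "./worker"; // module path assumed
    import type { MLWorkerElectron } from "./worker-electron";

    const startMLWorker = async (mlWorkerElectron: MLWorkerElectron) => {
        // wrap() returns a remote handle; its method calls are forwarded to
        // the worker as messages and resolve with the worker's replies.
        const remote = wrap<MLWorker>(
            new Worker(new URL("worker.ts", import.meta.url)),
        );
        // proxy() marks the object so comlink forwards calls to its functions
        // back to this (main) thread instead of attempting to structured-clone
        // them (functions cannot be cloned across the worker boundary).
        await remote.init(proxy(mlWorkerElectron));
        return remote;
    };

Passing the proxied object once at init time lets the worker keep calling these functions for its entire lifetime without further setup.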

View File

@@ -1,14 +0,0 @@
-/**
-* Callback functions made available to the ML worker.
-*
-* These primarily allow the worker code to access the functions exposed by our
-* desktop app. Normally code accesses this via `globalThis.electron`, but that
-* doesn't exist in the execution context of web workers.
-*
-* As such, this is currently a subset of {@link Electron}.
-*/
-export interface MLWorkerDelegate {
-appVersion: () => Promise<string>;
-detectFaces: (input: Float32Array) => Promise<Float32Array>;
-computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
-}

View File

@@ -0,0 +1,13 @@
+/**
+* A subset of {@link Electron} provided to the {@link MLWorker}.
+*
+* `globalThis.electron` does not exist in the execution context of web workers.
+* So instead, we manually provide a proxy object of type
+* {@link MLWorkerElectron} that exposes a subset of the functions from
+* {@link Electron} that are needed by the code running in the ML web worker.
+*/
+export interface MLWorkerElectron {
+appVersion: () => Promise<string>;
+detectFaces: (input: Float32Array) => Promise<Float32Array>;
+computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
+}
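
To make the indirection described above concrete, a hedged sketch of what a call through this interface looks like from inside the ML worker; detectFacesViaElectron is an illustrative wrapper, not part of the codebase:

    // Illustrative only: calling through the proxied subset from worker code.
    const detectFacesViaElectron = (
        electron: MLWorkerElectron,
        yoloInput: Float32Array,
    ): Promise<Float32Array> =>
        // 1. comlink serializes this call and posts it to the main thread.
        // 2. The main thread runs the corresponding globalThis.electron
        //    function, which hands the work to the desktop app's Node.js
        //    layer.
        // 3. The resulting Float32Array is posted back, resolving the promise
        //    here in the worker.
        electron.detectFaces(yoloInput);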

View File

@@ -18,7 +18,7 @@ import {
 } from "./db";
 import { pullFaceEmbeddings, putFaceIndex } from "./embedding";
 import { type FaceIndex, indexFaces } from "./face";
-import type { MLWorkerDelegate } from "./worker-delegate";
+import type { MLWorkerElectron } from "./worker-electron";
 const idleDurationStart = 5; /* 5 seconds */
 const idleDurationMax = 16 * 60; /* 16 minutes */
@@ -46,7 +46,7 @@ const idleDurationMax = 16 * 60; /* 16 minutes */
 * - "idle": in between state transitions
 */
 export class MLWorker {
-private delegate: MLWorkerDelegate | undefined;
+private electron: MLWorkerElectron | undefined;
 private userAgent: string | undefined;
 private shouldSync = false;
 private liveQ: { enteFile: EnteFile; uploadItem: UploadItem }[] = [];
@@ -60,13 +60,14 @@ export class MLWorker {
 * This is conceptually the constructor, however it is easier to have this
 * as a separate function to avoid confounding the comlink types too much.
 *
-* @param delegate The {@link MLWorkerDelegate} that allows the worker to
-* call back into the main thread.
+* @param electron The {@link MLWorkerElectron} that allows the worker to
+* use the functionality provided by our Node.js layer when running in the
+* context of our desktop app.
 */
-async init(delegate: MLWorkerDelegate) {
-this.delegate = delegate;
+async init(electron: MLWorkerElectron) {
+this.electron = electron;
 // Set the user agent that'll be set in the generated embeddings.
-this.userAgent = `${clientPackageName}/${await delegate.appVersion()}`;
+this.userAgent = `${clientPackageName}/${await electron.appVersion()}`;
 // Initialize the downloadManager running in the web worker with the
 // user's token. It'll be used to download files to index if needed.
 await downloadManager.init(await ensureAuthToken());
@@ -163,7 +164,7 @@ export class MLWorker {
 this.state = "indexing";
 const allSuccess = await indexNextBatch(
 liveQ,
-ensure(this.delegate),
+ensure(this.electron),
 ensure(this.userAgent),
 );
 if (allSuccess) {
@@ -208,7 +209,7 @@ const pull = pullFaceEmbeddings;
 */
 const indexNextBatch = async (
 liveQ: EnteFile[],
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 userAgent: string,
 ) => {
 if (!self.navigator.onLine) {
@@ -227,7 +228,7 @@ const indexNextBatch = async (
 let allSuccess = true;
 for (const file of files) {
 try {
-await index(file, undefined, delegate, userAgent);
+await index(file, undefined, electron, userAgent);
 // Possibly unnecessary, but let us drain the microtask queue.
 await wait(0);
 } catch {
@@ -288,7 +289,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
 export const index = async (
 enteFile: EnteFile,
 file: File | undefined,
-delegate: MLWorkerDelegate,
+electron: MLWorkerElectron,
 userAgent: string,
 ) => {
 const f = fileLogID(enteFile);
@@ -296,7 +297,7 @@ export const index = async (
 let faceIndex: FaceIndex;
 try {
-faceIndex = await indexFaces(enteFile, file, delegate, userAgent);
+faceIndex = await indexFaces(enteFile, file, electron, userAgent);
 } catch (e) {
 // Mark indexing as having failed only if the indexing itself
 // failed, not if there were subsequent failures (like when trying