diff --git a/web/packages/new/photos/services/ml/face.ts b/web/packages/new/photos/services/ml/face.ts index ffaad3bb16..394421ed0f 100644 --- a/web/packages/new/photos/services/ml/face.ts +++ b/web/packages/new/photos/services/ml/face.ts @@ -9,7 +9,6 @@ import type { EnteFile } from "@/new/photos/types/file"; import log from "@/next/log"; -import { workerBridge } from "@/next/worker/worker-bridge"; import { ensure } from "@/utils/ensure"; import { Matrix } from "ml-matrix"; import { getSimilarityTransformation } from "similarity-transformation"; @@ -28,6 +27,7 @@ import { pixelRGBBilinear, warpAffineFloat32List, } from "./image"; +import type { MLWorkerDelegate } from "./worker-delegate"; /** * The version of the face indexing pipeline implemented by the current client. @@ -221,6 +221,7 @@ export interface Box { export const indexFaces = async ( enteFile: EnteFile, file: File | undefined, + delegate: MLWorkerDelegate, userAgent: string, ) => { const imageBitmap = await renderableImageBitmap(enteFile, file); @@ -235,7 +236,7 @@ export const indexFaces = async ( faceEmbedding: { version: faceIndexingVersion, client: userAgent, - faces: await indexFacesInBitmap(fileID, imageBitmap), + faces: await indexFacesInBitmap(fileID, imageBitmap, delegate), }, }; // This step, saving face crops, is not part of the indexing pipeline; @@ -256,11 +257,12 @@ export const indexFaces = async ( const indexFacesInBitmap = async ( fileID: number, imageBitmap: ImageBitmap, + delegate: MLWorkerDelegate, ): Promise => { const { width, height } = imageBitmap; const imageDimensions = { width, height }; - const yoloFaceDetections = await detectFaces(imageBitmap); + const yoloFaceDetections = await detectFaces(imageBitmap, delegate); const partialResult = yoloFaceDetections.map( ({ box, landmarks, score }) => { const faceID = makeFaceID(fileID, box, imageDimensions); @@ -281,7 +283,7 @@ const indexFacesInBitmap = async ( alignments, ); - const embeddings = await computeEmbeddings(alignedFacesData); + const embeddings = await computeEmbeddings(alignedFacesData, delegate); const blurs = detectBlur( alignedFacesData, partialResult.map((f) => f.detection), @@ -303,6 +305,7 @@ const indexFacesInBitmap = async ( */ const detectFaces = async ( imageBitmap: ImageBitmap, + delegate: MLWorkerDelegate, ): Promise => { const rect = ({ width, height }: Dimensions) => ({ x: 0, @@ -313,7 +316,7 @@ const detectFaces = async ( const { yoloInput, yoloSize } = convertToYOLOInputFloat32ChannelsFirst(imageBitmap); - const yoloOutput = await workerBridge.detectFaces(yoloInput); + const yoloOutput = await delegate.detectFaces(yoloInput); const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput); const faceDetections = transformYOLOFaceDetections( faces, @@ -871,8 +874,9 @@ const mobileFaceNetEmbeddingSize = 192; */ const computeEmbeddings = async ( faceData: Float32Array, + delegate: MLWorkerDelegate, ): Promise => { - const outputData = await workerBridge.computeFaceEmbeddings(faceData); + const outputData = await delegate.computeFaceEmbeddings(faceData); const embeddingSize = mobileFaceNetEmbeddingSize; const embeddings = new Array( diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts index 92e1606631..c61014751f 100644 --- a/web/packages/new/photos/services/ml/index.ts +++ b/web/packages/new/photos/services/ml/index.ts @@ -8,11 +8,12 @@ import { isInternalUser, } from "@/new/photos/services/feature-flags"; import type { EnteFile } from "@/new/photos/types/file"; -import { clientPackageName, isDesktop } from "@/next/app"; +import { isDesktop } from "@/next/app"; import { blobCache } from "@/next/blob-cache"; import { ensureElectron } from "@/next/electron"; import log from "@/next/log"; import { ComlinkWorker } from "@/next/worker/comlink-worker"; +import { proxy } from "comlink"; import type { UploadItem } from "../upload/types"; import { regenerateFaceCrops } from "./crop"; import { clearFaceDB, faceIndex, indexableAndIndexedCounts } from "./db"; @@ -41,18 +42,21 @@ const worker = async () => { }; const createComlinkWorker = async () => { + const electron = ensureElectron(); + const delegate = { + appVersion: electron.appVersion, + detectFaces: electron.detectFaces, + computeFaceEmbeddings: electron.computeFaceEmbeddings, + }; + const cw = new ComlinkWorker( "ml", new Worker(new URL("worker.ts", import.meta.url)), ); - const ua = await getUserAgent(); - await cw.remote.then((w) => w.init(ua)); + await cw.remote.then((w) => w.init(proxy(delegate))); return cw; }; -const getUserAgent = async () => - `${clientPackageName}/${await ensureElectron().appVersion()}`; - /** * Terminate {@link worker} (if any). * @@ -173,12 +177,13 @@ export const triggerMLSync = () => { * @param enteFile The {@link EnteFile} that got uploaded. * * @param uploadItem The item that was uploaded. This can be used to get at the - * contents of the file that got uploaded. + * contents of the file that got uploaded. In case of live photos, this is the + * image part of the live photo that was uploaded. */ export const indexNewUpload = (enteFile: EnteFile, uploadItem: UploadItem) => { if (!_isMLEnabled) return; if (enteFile.metadata.fileType !== FILE_TYPE.IMAGE) return; - log.debug(() => ({ t: "ml-liveq", enteFile, uploadItem })); + log.debug(() => ({ t: "ml/liveq", enteFile, uploadItem })); void worker().then((w) => w.onUpload(enteFile, uploadItem)); }; diff --git a/web/packages/new/photos/services/ml/worker-delegate.ts b/web/packages/new/photos/services/ml/worker-delegate.ts new file mode 100644 index 0000000000..934f07f8b6 --- /dev/null +++ b/web/packages/new/photos/services/ml/worker-delegate.ts @@ -0,0 +1,14 @@ +/** + * Callback functions made available to the ML worker. + * + * These primarily allow the worker code to access the functions exposed by our + * desktop app. Normally code accesses this via `globalThis.electron`, but that + * doesn't exist in the execution context of web workers. + * + * As such, this is currently a subset of {@link Electron}. + */ +export interface MLWorkerDelegate { + appVersion: () => Promise; + detectFaces: (input: Float32Array) => Promise; + computeFaceEmbeddings: (input: Float32Array) => Promise; +} diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts index 4a67c18083..162e20ae06 100644 --- a/web/packages/new/photos/services/ml/worker.ts +++ b/web/packages/new/photos/services/ml/worker.ts @@ -1,5 +1,6 @@ import type { EnteFile } from "@/new/photos/types/file"; import { fileLogID } from "@/new/photos/utils/file"; +import { clientPackageName } from "@/next/app"; import { getKVN } from "@/next/kv"; import { ensureAuthToken } from "@/next/local-user"; import log from "@/next/log"; @@ -17,6 +18,7 @@ import { } from "./db"; import { pullFaceEmbeddings, putFaceIndex } from "./embedding"; import { type FaceIndex, indexFaces } from "./face"; +import type { MLWorkerDelegate } from "./worker-delegate"; const idleDurationStart = 5; /* 5 seconds */ const idleDurationMax = 16 * 60; /* 16 minutes */ @@ -44,6 +46,7 @@ const idleDurationMax = 16 * 60; /* 16 minutes */ * - "idle": in between state transitions */ export class MLWorker { + private delegate: MLWorkerDelegate | undefined; private userAgent: string | undefined; private shouldSync = false; private liveQ: { enteFile: EnteFile; uploadItem: UploadItem }[] = []; @@ -57,11 +60,13 @@ export class MLWorker { * This is conceptually the constructor, however it is easier to have this * as a separate function to avoid confounding the comlink types too much. * - * @param userAgent The user agent string to use as the client field in the - * embeddings generated during indexing by this client. + * @param delegate The {@link MLWorkerDelegate} that allows the worker to + * call back into the main thread. */ - async init(userAgent: string) { - this.userAgent = userAgent; + async init(delegate: MLWorkerDelegate) { + this.delegate = delegate; + // Set the user agent that'll be set in the generated embeddings. + this.userAgent = `${clientPackageName}/${await delegate.appVersion()}`; // Initialize the downloadManager running in the web worker with the // user's token. It'll be used to download files to index if needed. await downloadManager.init(await ensureAuthToken()); @@ -130,7 +135,7 @@ export class MLWorker { private async tick() { log.debug(() => ({ - t: "ml-tick", + t: "ml/tick", state: this.state, shouldSync: this.shouldSync, liveQ: this.liveQ, @@ -156,7 +161,11 @@ export class MLWorker { const liveQ = this.liveQ.map((i) => i.enteFile); this.liveQ = []; this.state = "indexing"; - const allSuccess = await indexNextBatch(ensure(this.userAgent), liveQ); + const allSuccess = await indexNextBatch( + liveQ, + ensure(this.delegate), + ensure(this.userAgent), + ); if (allSuccess) { // Everything is running smoothly. Reset the idle duration. this.idleDuration = idleDurationStart; @@ -197,7 +206,11 @@ const pull = pullFaceEmbeddings; * Which means that when it returns true, all is well and there are more * things pending to process, so we should chug along at full speed. */ -const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => { +const indexNextBatch = async ( + liveQ: EnteFile[], + delegate: MLWorkerDelegate, + userAgent: string, +) => { if (!self.navigator.onLine) { log.info("Skipping ML indexing since we are not online"); return false; @@ -214,7 +227,7 @@ const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => { let allSuccess = true; for (const file of files) { try { - await index(file, undefined, userAgent); + await index(file, undefined, delegate, userAgent); // Possibly unnecessary, but let us drain the microtask queue. await wait(0); } catch { @@ -275,6 +288,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async ( export const index = async ( enteFile: EnteFile, file: File | undefined, + delegate: MLWorkerDelegate, userAgent: string, ) => { const f = fileLogID(enteFile); @@ -282,7 +296,7 @@ export const index = async ( let faceIndex: FaceIndex; try { - faceIndex = await indexFaces(enteFile, file, userAgent); + faceIndex = await indexFaces(enteFile, file, delegate, userAgent); } catch (e) { // Mark indexing as having failed only if the indexing itself // failed, not if there were subsequent failures (like when trying diff --git a/web/packages/next/worker/comlink-worker.ts b/web/packages/next/worker/comlink-worker.ts index 1a6645d334..75a5a07020 100644 --- a/web/packages/next/worker/comlink-worker.ts +++ b/web/packages/next/worker/comlink-worker.ts @@ -3,6 +3,21 @@ import log, { logToDisk } from "@/next/log"; import { expose, wrap, type Remote } from "comlink"; import { ensureLocalUser } from "../local-user"; +/** + * A minimal wrapper for a web {@link Worker}, proxying a class of type T. + * + * `comlink` is a library that simplies working with web workers by + * transparently proxying objects across the boundary instead of us needing to + * work directly with the raw postMessage interface. + * + * This class is a thin wrapper over a common usage pattern of comlink. It takes + * a web worker ({@link Worker}) that is expected to have {@link expose}-ed a + * class of type T. It then makes available the main thread handle to this class + * as the {@link remote} property. + * + * It also exposes an object of type {@link WorkerBridge} _to_ the code running + * inside the web worker. + */ export class ComlinkWorker InstanceType> { public remote: Promise>>; private worker: Worker; @@ -17,7 +32,7 @@ export class ComlinkWorker InstanceType> { `Got error event from worker: ${JSON.stringify({ event, name })}`, ); }; - log.debug(() => `Initiated web worker ${name}`); + log.debug(() => `Created ${name} web worker`); const comlink = wrap(worker); this.remote = new comlink() as Promise>>; expose(workerBridge, worker); @@ -25,7 +40,7 @@ export class ComlinkWorker InstanceType> { public terminate() { this.worker.terminate(); - log.debug(() => `Terminated web worker ${this.name}`); + log.debug(() => `Terminated ${this.name} web worker`); } } @@ -40,15 +55,12 @@ export class ComlinkWorker InstanceType> { * is a motley bunch, we just inject them all. */ const workerBridge = { - // Needed: generally (presumably) + // Needed by all workers (likely, not necessarily). logToDisk, - // Needed by ML worker + // Needed by MLWorker. getAuthToken: () => ensureLocalUser().token, convertToJPEG: (imageData: Uint8Array) => ensureElectron().convertToJPEG(imageData), - detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input), - computeFaceEmbeddings: (input: Float32Array) => - ensureElectron().computeFaceEmbeddings(input), }; export type WorkerBridge = typeof workerBridge;