delegate
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
|
||||
import type { EnteFile } from "@/new/photos/types/file";
|
||||
import log from "@/next/log";
|
||||
import { workerBridge } from "@/next/worker/worker-bridge";
|
||||
import { ensure } from "@/utils/ensure";
|
||||
import { Matrix } from "ml-matrix";
|
||||
import { getSimilarityTransformation } from "similarity-transformation";
|
||||
@@ -28,6 +27,7 @@ import {
|
||||
pixelRGBBilinear,
|
||||
warpAffineFloat32List,
|
||||
} from "./image";
|
||||
import type { MLWorkerDelegate } from "./worker-delegate";
|
||||
|
||||
/**
|
||||
* The version of the face indexing pipeline implemented by the current client.
|
||||
@@ -221,6 +221,7 @@ export interface Box {
|
||||
export const indexFaces = async (
|
||||
enteFile: EnteFile,
|
||||
file: File | undefined,
|
||||
delegate: MLWorkerDelegate,
|
||||
userAgent: string,
|
||||
) => {
|
||||
const imageBitmap = await renderableImageBitmap(enteFile, file);
|
||||
@@ -235,7 +236,7 @@ export const indexFaces = async (
|
||||
faceEmbedding: {
|
||||
version: faceIndexingVersion,
|
||||
client: userAgent,
|
||||
faces: await indexFacesInBitmap(fileID, imageBitmap),
|
||||
faces: await indexFacesInBitmap(fileID, imageBitmap, delegate),
|
||||
},
|
||||
};
|
||||
// This step, saving face crops, is not part of the indexing pipeline;
|
||||
@@ -256,11 +257,12 @@ export const indexFaces = async (
|
||||
const indexFacesInBitmap = async (
|
||||
fileID: number,
|
||||
imageBitmap: ImageBitmap,
|
||||
delegate: MLWorkerDelegate,
|
||||
): Promise<Face[]> => {
|
||||
const { width, height } = imageBitmap;
|
||||
const imageDimensions = { width, height };
|
||||
|
||||
const yoloFaceDetections = await detectFaces(imageBitmap);
|
||||
const yoloFaceDetections = await detectFaces(imageBitmap, delegate);
|
||||
const partialResult = yoloFaceDetections.map(
|
||||
({ box, landmarks, score }) => {
|
||||
const faceID = makeFaceID(fileID, box, imageDimensions);
|
||||
@@ -281,7 +283,7 @@ const indexFacesInBitmap = async (
|
||||
alignments,
|
||||
);
|
||||
|
||||
const embeddings = await computeEmbeddings(alignedFacesData);
|
||||
const embeddings = await computeEmbeddings(alignedFacesData, delegate);
|
||||
const blurs = detectBlur(
|
||||
alignedFacesData,
|
||||
partialResult.map((f) => f.detection),
|
||||
@@ -303,6 +305,7 @@ const indexFacesInBitmap = async (
|
||||
*/
|
||||
const detectFaces = async (
|
||||
imageBitmap: ImageBitmap,
|
||||
delegate: MLWorkerDelegate,
|
||||
): Promise<YOLOFaceDetection[]> => {
|
||||
const rect = ({ width, height }: Dimensions) => ({
|
||||
x: 0,
|
||||
@@ -313,7 +316,7 @@ const detectFaces = async (
|
||||
|
||||
const { yoloInput, yoloSize } =
|
||||
convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
|
||||
const yoloOutput = await workerBridge.detectFaces(yoloInput);
|
||||
const yoloOutput = await delegate.detectFaces(yoloInput);
|
||||
const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput);
|
||||
const faceDetections = transformYOLOFaceDetections(
|
||||
faces,
|
||||
@@ -871,8 +874,9 @@ const mobileFaceNetEmbeddingSize = 192;
|
||||
*/
|
||||
const computeEmbeddings = async (
|
||||
faceData: Float32Array,
|
||||
delegate: MLWorkerDelegate,
|
||||
): Promise<Float32Array[]> => {
|
||||
const outputData = await workerBridge.computeFaceEmbeddings(faceData);
|
||||
const outputData = await delegate.computeFaceEmbeddings(faceData);
|
||||
|
||||
const embeddingSize = mobileFaceNetEmbeddingSize;
|
||||
const embeddings = new Array<Float32Array>(
|
||||
|
||||
@@ -8,11 +8,12 @@ import {
|
||||
isInternalUser,
|
||||
} from "@/new/photos/services/feature-flags";
|
||||
import type { EnteFile } from "@/new/photos/types/file";
|
||||
import { clientPackageName, isDesktop } from "@/next/app";
|
||||
import { isDesktop } from "@/next/app";
|
||||
import { blobCache } from "@/next/blob-cache";
|
||||
import { ensureElectron } from "@/next/electron";
|
||||
import log from "@/next/log";
|
||||
import { ComlinkWorker } from "@/next/worker/comlink-worker";
|
||||
import { proxy } from "comlink";
|
||||
import type { UploadItem } from "../upload/types";
|
||||
import { regenerateFaceCrops } from "./crop";
|
||||
import { clearFaceDB, faceIndex, indexableAndIndexedCounts } from "./db";
|
||||
@@ -41,18 +42,21 @@ const worker = async () => {
|
||||
};
|
||||
|
||||
const createComlinkWorker = async () => {
|
||||
const electron = ensureElectron();
|
||||
const delegate = {
|
||||
appVersion: electron.appVersion,
|
||||
detectFaces: electron.detectFaces,
|
||||
computeFaceEmbeddings: electron.computeFaceEmbeddings,
|
||||
};
|
||||
|
||||
const cw = new ComlinkWorker<typeof MLWorker>(
|
||||
"ml",
|
||||
new Worker(new URL("worker.ts", import.meta.url)),
|
||||
);
|
||||
const ua = await getUserAgent();
|
||||
await cw.remote.then((w) => w.init(ua));
|
||||
await cw.remote.then((w) => w.init(proxy(delegate)));
|
||||
return cw;
|
||||
};
|
||||
|
||||
const getUserAgent = async () =>
|
||||
`${clientPackageName}/${await ensureElectron().appVersion()}`;
|
||||
|
||||
/**
|
||||
* Terminate {@link worker} (if any).
|
||||
*
|
||||
@@ -173,12 +177,13 @@ export const triggerMLSync = () => {
|
||||
* @param enteFile The {@link EnteFile} that got uploaded.
|
||||
*
|
||||
* @param uploadItem The item that was uploaded. This can be used to get at the
|
||||
* contents of the file that got uploaded.
|
||||
* contents of the file that got uploaded. In case of live photos, this is the
|
||||
* image part of the live photo that was uploaded.
|
||||
*/
|
||||
export const indexNewUpload = (enteFile: EnteFile, uploadItem: UploadItem) => {
|
||||
if (!_isMLEnabled) return;
|
||||
if (enteFile.metadata.fileType !== FILE_TYPE.IMAGE) return;
|
||||
log.debug(() => ({ t: "ml-liveq", enteFile, uploadItem }));
|
||||
log.debug(() => ({ t: "ml/liveq", enteFile, uploadItem }));
|
||||
void worker().then((w) => w.onUpload(enteFile, uploadItem));
|
||||
};
|
||||
|
||||
|
||||
14
web/packages/new/photos/services/ml/worker-delegate.ts
Normal file
14
web/packages/new/photos/services/ml/worker-delegate.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Callback functions made available to the ML worker.
|
||||
*
|
||||
* These primarily allow the worker code to access the functions exposed by our
|
||||
* desktop app. Normally code accesses this via `globalThis.electron`, but that
|
||||
* doesn't exist in the execution context of web workers.
|
||||
*
|
||||
* As such, this is currently a subset of {@link Electron}.
|
||||
*/
|
||||
export interface MLWorkerDelegate {
|
||||
/**
 * Return the app's version string. The ML worker appends this to the client
 * package name to form the user agent recorded in generated embeddings.
 */
appVersion: () => Promise<string>;
|
||||
/**
 * Run face detection on the given input. The face indexing code passes the
 * YOLO input it prepared from the image, and then filters detections out of
 * the returned output.
 */
detectFaces: (input: Float32Array) => Promise<Float32Array>;
|
||||
/**
 * Compute face embeddings for the given input. The face indexing code
 * passes the aligned face data, and then splits the returned output into
 * per-face embeddings.
 */
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { EnteFile } from "@/new/photos/types/file";
|
||||
import { fileLogID } from "@/new/photos/utils/file";
|
||||
import { clientPackageName } from "@/next/app";
|
||||
import { getKVN } from "@/next/kv";
|
||||
import { ensureAuthToken } from "@/next/local-user";
|
||||
import log from "@/next/log";
|
||||
@@ -17,6 +18,7 @@ import {
|
||||
} from "./db";
|
||||
import { pullFaceEmbeddings, putFaceIndex } from "./embedding";
|
||||
import { type FaceIndex, indexFaces } from "./face";
|
||||
import type { MLWorkerDelegate } from "./worker-delegate";
|
||||
|
||||
const idleDurationStart = 5; /* 5 seconds */
|
||||
const idleDurationMax = 16 * 60; /* 16 minutes */
|
||||
@@ -44,6 +46,7 @@ const idleDurationMax = 16 * 60; /* 16 minutes */
|
||||
* - "idle": in between state transitions
|
||||
*/
|
||||
export class MLWorker {
|
||||
private delegate: MLWorkerDelegate | undefined;
|
||||
private userAgent: string | undefined;
|
||||
private shouldSync = false;
|
||||
private liveQ: { enteFile: EnteFile; uploadItem: UploadItem }[] = [];
|
||||
@@ -57,11 +60,13 @@ export class MLWorker {
|
||||
* This is conceptually the constructor, however it is easier to have this
|
||||
* as a separate function to avoid confounding the comlink types too much.
|
||||
*
|
||||
* @param userAgent The user agent string to use as the client field in the
|
||||
* embeddings generated during indexing by this client.
|
||||
* @param delegate The {@link MLWorkerDelegate} that allows the worker to
|
||||
* call back into the main thread.
|
||||
*/
|
||||
async init(userAgent: string) {
|
||||
this.userAgent = userAgent;
|
||||
async init(delegate: MLWorkerDelegate) {
|
||||
this.delegate = delegate;
|
||||
// Set the user agent that'll be set in the generated embeddings.
|
||||
this.userAgent = `${clientPackageName}/${await delegate.appVersion()}`;
|
||||
// Initialize the downloadManager running in the web worker with the
|
||||
// user's token. It'll be used to download files to index if needed.
|
||||
await downloadManager.init(await ensureAuthToken());
|
||||
@@ -130,7 +135,7 @@ export class MLWorker {
|
||||
|
||||
private async tick() {
|
||||
log.debug(() => ({
|
||||
t: "ml-tick",
|
||||
t: "ml/tick",
|
||||
state: this.state,
|
||||
shouldSync: this.shouldSync,
|
||||
liveQ: this.liveQ,
|
||||
@@ -156,7 +161,11 @@ export class MLWorker {
|
||||
const liveQ = this.liveQ.map((i) => i.enteFile);
|
||||
this.liveQ = [];
|
||||
this.state = "indexing";
|
||||
const allSuccess = await indexNextBatch(ensure(this.userAgent), liveQ);
|
||||
const allSuccess = await indexNextBatch(
|
||||
liveQ,
|
||||
ensure(this.delegate),
|
||||
ensure(this.userAgent),
|
||||
);
|
||||
if (allSuccess) {
|
||||
// Everything is running smoothly. Reset the idle duration.
|
||||
this.idleDuration = idleDurationStart;
|
||||
@@ -197,7 +206,11 @@ const pull = pullFaceEmbeddings;
|
||||
* Which means that when it returns true, all is well and there are more
|
||||
* things pending to process, so we should chug along at full speed.
|
||||
*/
|
||||
const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => {
|
||||
const indexNextBatch = async (
|
||||
liveQ: EnteFile[],
|
||||
delegate: MLWorkerDelegate,
|
||||
userAgent: string,
|
||||
) => {
|
||||
if (!self.navigator.onLine) {
|
||||
log.info("Skipping ML indexing since we are not online");
|
||||
return false;
|
||||
@@ -214,7 +227,7 @@ const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => {
|
||||
let allSuccess = true;
|
||||
for (const file of files) {
|
||||
try {
|
||||
await index(file, undefined, userAgent);
|
||||
await index(file, undefined, delegate, userAgent);
|
||||
// Possibly unnecessary, but let us drain the microtask queue.
|
||||
await wait(0);
|
||||
} catch {
|
||||
@@ -275,6 +288,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
|
||||
export const index = async (
|
||||
enteFile: EnteFile,
|
||||
file: File | undefined,
|
||||
delegate: MLWorkerDelegate,
|
||||
userAgent: string,
|
||||
) => {
|
||||
const f = fileLogID(enteFile);
|
||||
@@ -282,7 +296,7 @@ export const index = async (
|
||||
|
||||
let faceIndex: FaceIndex;
|
||||
try {
|
||||
faceIndex = await indexFaces(enteFile, file, userAgent);
|
||||
faceIndex = await indexFaces(enteFile, file, delegate, userAgent);
|
||||
} catch (e) {
|
||||
// Mark indexing as having failed only if the indexing itself
|
||||
// failed, not if there were subsequent failures (like when trying
|
||||
|
||||
@@ -3,6 +3,21 @@ import log, { logToDisk } from "@/next/log";
|
||||
import { expose, wrap, type Remote } from "comlink";
|
||||
import { ensureLocalUser } from "../local-user";
|
||||
|
||||
/**
|
||||
* A minimal wrapper for a web {@link Worker}, proxying a class of type T.
|
||||
*
|
||||
* `comlink` is a library that simplifies working with web workers by
|
||||
* transparently proxying objects across the boundary instead of us needing to
|
||||
* work directly with the raw postMessage interface.
|
||||
*
|
||||
* This class is a thin wrapper over a common usage pattern of comlink. It takes
|
||||
* a web worker ({@link Worker}) that is expected to have {@link expose}-ed a
|
||||
* class of type T. It then makes available the main thread handle to this class
|
||||
* as the {@link remote} property.
|
||||
*
|
||||
* It also exposes an object of type {@link WorkerBridge} _to_ the code running
|
||||
* inside the web worker.
|
||||
*/
|
||||
export class ComlinkWorker<T extends new () => InstanceType<T>> {
|
||||
public remote: Promise<Remote<InstanceType<T>>>;
|
||||
private worker: Worker;
|
||||
@@ -17,7 +32,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
|
||||
`Got error event from worker: ${JSON.stringify({ event, name })}`,
|
||||
);
|
||||
};
|
||||
log.debug(() => `Initiated web worker ${name}`);
|
||||
log.debug(() => `Created ${name} web worker`);
|
||||
const comlink = wrap<T>(worker);
|
||||
this.remote = new comlink() as Promise<Remote<InstanceType<T>>>;
|
||||
expose(workerBridge, worker);
|
||||
@@ -25,7 +40,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
|
||||
|
||||
public terminate() {
|
||||
this.worker.terminate();
|
||||
log.debug(() => `Terminated web worker ${this.name}`);
|
||||
log.debug(() => `Terminated ${this.name} web worker`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,15 +55,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
|
||||
* is a motley bunch, we just inject them all.
|
||||
*/
|
||||
const workerBridge = {
|
||||
// Needed: generally (presumably)
|
||||
// Needed by all workers (likely, not necessarily).
|
||||
logToDisk,
|
||||
// Needed by ML worker
|
||||
// Needed by MLWorker.
|
||||
getAuthToken: () => ensureLocalUser().token,
|
||||
convertToJPEG: (imageData: Uint8Array) =>
|
||||
ensureElectron().convertToJPEG(imageData),
|
||||
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
|
||||
computeFaceEmbeddings: (input: Float32Array) =>
|
||||
ensureElectron().computeFaceEmbeddings(input),
|
||||
};
|
||||
|
||||
export type WorkerBridge = typeof workerBridge;
|
||||
|
||||
Reference in New Issue
Block a user