This commit is contained in:
Manav Rathi
2024-07-04 18:45:13 +05:30
parent 95ef21ae4f
commit c3bbceae49
5 changed files with 79 additions and 30 deletions

View File

@@ -9,7 +9,6 @@
import type { EnteFile } from "@/new/photos/types/file";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import { ensure } from "@/utils/ensure";
import { Matrix } from "ml-matrix";
import { getSimilarityTransformation } from "similarity-transformation";
@@ -28,6 +27,7 @@ import {
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
import type { MLWorkerDelegate } from "./worker-delegate";
/**
* The version of the face indexing pipeline implemented by the current client.
@@ -221,6 +221,7 @@ export interface Box {
export const indexFaces = async (
enteFile: EnteFile,
file: File | undefined,
delegate: MLWorkerDelegate,
userAgent: string,
) => {
const imageBitmap = await renderableImageBitmap(enteFile, file);
@@ -235,7 +236,7 @@ export const indexFaces = async (
faceEmbedding: {
version: faceIndexingVersion,
client: userAgent,
faces: await indexFacesInBitmap(fileID, imageBitmap),
faces: await indexFacesInBitmap(fileID, imageBitmap, delegate),
},
};
// This step, saving face crops, is not part of the indexing pipeline;
@@ -256,11 +257,12 @@ export const indexFaces = async (
const indexFacesInBitmap = async (
fileID: number,
imageBitmap: ImageBitmap,
delegate: MLWorkerDelegate,
): Promise<Face[]> => {
const { width, height } = imageBitmap;
const imageDimensions = { width, height };
const yoloFaceDetections = await detectFaces(imageBitmap);
const yoloFaceDetections = await detectFaces(imageBitmap, delegate);
const partialResult = yoloFaceDetections.map(
({ box, landmarks, score }) => {
const faceID = makeFaceID(fileID, box, imageDimensions);
@@ -281,7 +283,7 @@ const indexFacesInBitmap = async (
alignments,
);
const embeddings = await computeEmbeddings(alignedFacesData);
const embeddings = await computeEmbeddings(alignedFacesData, delegate);
const blurs = detectBlur(
alignedFacesData,
partialResult.map((f) => f.detection),
@@ -303,6 +305,7 @@ const indexFacesInBitmap = async (
*/
const detectFaces = async (
imageBitmap: ImageBitmap,
delegate: MLWorkerDelegate,
): Promise<YOLOFaceDetection[]> => {
const rect = ({ width, height }: Dimensions) => ({
x: 0,
@@ -313,7 +316,7 @@ const detectFaces = async (
const { yoloInput, yoloSize } =
convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
const yoloOutput = await workerBridge.detectFaces(yoloInput);
const yoloOutput = await delegate.detectFaces(yoloInput);
const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput);
const faceDetections = transformYOLOFaceDetections(
faces,
@@ -871,8 +874,9 @@ const mobileFaceNetEmbeddingSize = 192;
*/
const computeEmbeddings = async (
faceData: Float32Array,
delegate: MLWorkerDelegate,
): Promise<Float32Array[]> => {
const outputData = await workerBridge.computeFaceEmbeddings(faceData);
const outputData = await delegate.computeFaceEmbeddings(faceData);
const embeddingSize = mobileFaceNetEmbeddingSize;
const embeddings = new Array<Float32Array>(

View File

@@ -8,11 +8,12 @@ import {
isInternalUser,
} from "@/new/photos/services/feature-flags";
import type { EnteFile } from "@/new/photos/types/file";
import { clientPackageName, isDesktop } from "@/next/app";
import { isDesktop } from "@/next/app";
import { blobCache } from "@/next/blob-cache";
import { ensureElectron } from "@/next/electron";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { proxy } from "comlink";
import type { UploadItem } from "../upload/types";
import { regenerateFaceCrops } from "./crop";
import { clearFaceDB, faceIndex, indexableAndIndexedCounts } from "./db";
@@ -41,18 +42,21 @@ const worker = async () => {
};
const createComlinkWorker = async () => {
const electron = ensureElectron();
const delegate = {
appVersion: electron.appVersion,
detectFaces: electron.detectFaces,
computeFaceEmbeddings: electron.computeFaceEmbeddings,
};
const cw = new ComlinkWorker<typeof MLWorker>(
"ml",
new Worker(new URL("worker.ts", import.meta.url)),
);
const ua = await getUserAgent();
await cw.remote.then((w) => w.init(ua));
await cw.remote.then((w) => w.init(proxy(delegate)));
return cw;
};
const getUserAgent = async () =>
`${clientPackageName}/${await ensureElectron().appVersion()}`;
/**
* Terminate {@link worker} (if any).
*
@@ -173,12 +177,13 @@ export const triggerMLSync = () => {
* @param enteFile The {@link EnteFile} that got uploaded.
*
* @param uploadItem The item that was uploaded. This can be used to get at the
* contents of the file that got uploaded.
* contents of the file that got uploaded. In case of live photos, this is the
* image part of the live photo that was uploaded.
*/
export const indexNewUpload = (enteFile: EnteFile, uploadItem: UploadItem) => {
// Bail out when the `_isMLEnabled` flag is not set.
if (!_isMLEnabled) return;
// Only files whose metadata marks them as images are forwarded for indexing.
if (enteFile.metadata.fileType !== FILE_TYPE.IMAGE) return;
log.debug(() => ({ t: "ml-liveq", enteFile, uploadItem }));
log.debug(() => ({ t: "ml/liveq", enteFile, uploadItem }));
// Fire and forget: hand the upload over to the ML worker without awaiting
// the result (hence the explicit `void`).
void worker().then((w) => w.onUpload(enteFile, uploadItem));
};

View File

@@ -0,0 +1,14 @@
/**
* Callback functions made available to the ML worker.
*
* These primarily allow the worker code to access the functions exposed by our
* desktop app. Normally code accesses this via `globalThis.electron`, but that
* doesn't exist in the execution context of web workers.
*
* As such, this is currently a subset of {@link Electron}.
*
* The main thread builds the delegate from the corresponding `electron`
* functions and hands it to the worker wrapped in comlink's `proxy`, which is
* why every member returns a promise.
*/
export interface MLWorkerDelegate {
/**
* Return the app version string.
*
* The ML worker combines this with the client package name to form the user
* agent that is recorded in the embeddings it generates.
*/
appVersion: () => Promise<string>;
/**
* Run face detection on the given input.
*
* The face indexing code passes the YOLO input it prepared from the image
* bitmap, and then filters the returned output into face detections.
*
* NOTE(review): the expected layout of the input and output arrays is not
* specified here; confirm against the `Electron.detectFaces` documentation.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Compute face embeddings for the given input.
*
* The face indexing code passes the aligned face data here.
*
* NOTE(review): presumably the result is a flat array holding one embedding
* per face, laid out back to back; confirm against the
* `Electron.computeFaceEmbeddings` documentation.
*/
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
}

View File

@@ -1,5 +1,6 @@
import type { EnteFile } from "@/new/photos/types/file";
import { fileLogID } from "@/new/photos/utils/file";
import { clientPackageName } from "@/next/app";
import { getKVN } from "@/next/kv";
import { ensureAuthToken } from "@/next/local-user";
import log from "@/next/log";
@@ -17,6 +18,7 @@ import {
} from "./db";
import { pullFaceEmbeddings, putFaceIndex } from "./embedding";
import { type FaceIndex, indexFaces } from "./face";
import type { MLWorkerDelegate } from "./worker-delegate";
const idleDurationStart = 5; /* 5 seconds */
const idleDurationMax = 16 * 60; /* 16 minutes */
@@ -44,6 +46,7 @@ const idleDurationMax = 16 * 60; /* 16 minutes */
* - "idle": in between state transitions
*/
export class MLWorker {
private delegate: MLWorkerDelegate | undefined;
private userAgent: string | undefined;
private shouldSync = false;
private liveQ: { enteFile: EnteFile; uploadItem: UploadItem }[] = [];
@@ -57,11 +60,13 @@ export class MLWorker {
* This is conceptually the constructor, however it is easier to have this
* as a separate function to avoid confounding the comlink types too much.
*
* @param userAgent The user agent string to use as the client field in the
* embeddings generated during indexing by this client.
* @param delegate The {@link MLWorkerDelegate} that allows the worker to
* call back into the main thread.
*/
async init(userAgent: string) {
this.userAgent = userAgent;
async init(delegate: MLWorkerDelegate) {
this.delegate = delegate;
// Set the user agent that'll be set in the generated embeddings.
this.userAgent = `${clientPackageName}/${await delegate.appVersion()}`;
// Initialize the downloadManager running in the web worker with the
// user's token. It'll be used to download files to index if needed.
await downloadManager.init(await ensureAuthToken());
@@ -130,7 +135,7 @@ export class MLWorker {
private async tick() {
log.debug(() => ({
t: "ml-tick",
t: "ml/tick",
state: this.state,
shouldSync: this.shouldSync,
liveQ: this.liveQ,
@@ -156,7 +161,11 @@ export class MLWorker {
const liveQ = this.liveQ.map((i) => i.enteFile);
this.liveQ = [];
this.state = "indexing";
const allSuccess = await indexNextBatch(ensure(this.userAgent), liveQ);
const allSuccess = await indexNextBatch(
liveQ,
ensure(this.delegate),
ensure(this.userAgent),
);
if (allSuccess) {
// Everything is running smoothly. Reset the idle duration.
this.idleDuration = idleDurationStart;
@@ -197,7 +206,11 @@ const pull = pullFaceEmbeddings;
* Which means that when it returns true, all is well and there are more
* things pending to process, so we should chug along at full speed.
*/
const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => {
const indexNextBatch = async (
liveQ: EnteFile[],
delegate: MLWorkerDelegate,
userAgent: string,
) => {
if (!self.navigator.onLine) {
log.info("Skipping ML indexing since we are not online");
return false;
@@ -214,7 +227,7 @@ const indexNextBatch = async (userAgent: string, liveQ: EnteFile[]) => {
let allSuccess = true;
for (const file of files) {
try {
await index(file, undefined, userAgent);
await index(file, undefined, delegate, userAgent);
// Possibly unnecessary, but let us drain the microtask queue.
await wait(0);
} catch {
@@ -275,6 +288,7 @@ const syncWithLocalFilesAndGetFilesToIndex = async (
export const index = async (
enteFile: EnteFile,
file: File | undefined,
delegate: MLWorkerDelegate,
userAgent: string,
) => {
const f = fileLogID(enteFile);
@@ -282,7 +296,7 @@ export const index = async (
let faceIndex: FaceIndex;
try {
faceIndex = await indexFaces(enteFile, file, userAgent);
faceIndex = await indexFaces(enteFile, file, delegate, userAgent);
} catch (e) {
// Mark indexing as having failed only if the indexing itself
// failed, not if there were subsequent failures (like when trying

View File

@@ -3,6 +3,21 @@ import log, { logToDisk } from "@/next/log";
import { expose, wrap, type Remote } from "comlink";
import { ensureLocalUser } from "../local-user";
/**
* A minimal wrapper for a web {@link Worker}, proxying a class of type T.
*
* `comlink` is a library that simplifies working with web workers by
* transparently proxying objects across the boundary instead of us needing to
* work directly with the raw postMessage interface.
*
* This class is a thin wrapper over a common usage pattern of comlink. It takes
* a web worker ({@link Worker}) that is expected to have {@link expose}-ed a
* class of type T. It then makes available the main thread handle to this class
* as the {@link remote} property.
*
* It also exposes an object of type {@link WorkerBridge} _to_ the code running
* inside the web worker.
*/
export class ComlinkWorker<T extends new () => InstanceType<T>> {
public remote: Promise<Remote<InstanceType<T>>>;
private worker: Worker;
@@ -17,7 +32,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
`Got error event from worker: ${JSON.stringify({ event, name })}`,
);
};
log.debug(() => `Initiated web worker ${name}`);
log.debug(() => `Created ${name} web worker`);
const comlink = wrap<T>(worker);
this.remote = new comlink() as Promise<Remote<InstanceType<T>>>;
expose(workerBridge, worker);
@@ -25,7 +40,7 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
public terminate() {
this.worker.terminate();
log.debug(() => `Terminated web worker ${this.name}`);
log.debug(() => `Terminated ${this.name} web worker`);
}
}
@@ -40,15 +55,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
* is a motley bunch, we just inject them all.
*/
const workerBridge = {
// Needed: generally (presumably)
// Needed by all workers (likely, not necessarily).
logToDisk,
// Needed by ML worker
// Needed by MLWorker.
getAuthToken: () => ensureLocalUser().token,
convertToJPEG: (imageData: Uint8Array) =>
ensureElectron().convertToJPEG(imageData),
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
computeFaceEmbeddings: (input: Float32Array) =>
ensureElectron().computeFaceEmbeddings(input),
};
export type WorkerBridge = typeof workerBridge;