diff --git a/web/packages/new/photos/services/ml/embedding.ts b/web/packages/new/photos/services/ml/embedding.ts index 6da4e765d1..11df71bc84 100644 --- a/web/packages/new/photos/services/ml/embedding.ts +++ b/web/packages/new/photos/services/ml/embedding.ts @@ -235,7 +235,7 @@ const remoteDerivedDataFromJSONString = (jsonString: string) => { * @param fileIDs The ids of the files for which we want the embeddings. * * @returns a list of {@link RemoteEmbedding} for the files which had embeddings - * (and thatt remote was able to successfully retrieve). The order of this list + * (and that remote was able to successfully retrieve). The order of this list * is arbitrary, and the caller should use the {@link fileID} present within the * {@link RemoteEmbedding} to associate an item in the result back to a file * instead of relying on the order or count of items in the result. @@ -308,3 +308,50 @@ const putEmbedding = async ( }); ensureOk(res); }; + +/** A single entry in the response of {@link getIndexedFiles}. */ +const IndexedFile = z.object({ + fileID: z.number(), + updatedAt: z.number(), +}); + +type IndexedFile = z.infer<typeof IndexedFile>; + +/** + * Fetch the file ids whose {@link model} derived data has been created or + * updated since the given {@link sinceTime}. + * + * This allows a client to perform a quick "diff" and first fetch all derived + * data that has changed since the last time it checked. By fetching these all + * upfront instead of doing them one by one during the indexing, we can speed up + * the initial sync of existing embeddings on a new client. + * + * @param sinceTime Epoch milliseconds. We use this to ask remote to provide us + * derived data whose {@link updatedAt} is more than the given value. If not + * specified, then we'll start from the beginning. + * + * @param limit An advisory limit on the number of items to return. + * + * @returns an array of file ids, each with an associated timestamp when the + * derived data for that file was last changed. 
+ * + * The caller should persist the latest amongst these timestamps and use it in + * subsequent calls to resume pulling from the current checkpoint. + */ +export const getIndexedDerivedDataFiles = async ( + sinceTime: number, + limit: number, +): Promise<IndexedFile[]> => { + const params = new URLSearchParams({ + model: "derived", + sinceTime: sinceTime.toString(), + limit: limit.toString(), + }); + const url = await apiURL("/embeddings/indexed-files"); + const res = await fetch(`${url}?${params.toString()}`, { + headers: await authenticatedRequestHeaders(), + }); + ensureOk(res); + return z.object({ diff: z.array(IndexedFile) }).parse(await res.json()) + .diff; +}; diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts index 37d6dc259e..aecf7b45d3 100644 --- a/web/packages/new/photos/services/ml/worker.ts +++ b/web/packages/new/photos/services/ml/worker.ts @@ -1,6 +1,6 @@ import { clientPackageName } from "@/base/app"; import { isHTTP4xxError } from "@/base/http"; -import { getKVN } from "@/base/kv"; +import { getKVN, setKV } from "@/base/kv"; import { ensureAuthToken } from "@/base/local-user"; import log from "@/base/log"; import type { ElectronMLWorker } from "@/base/types/ipc"; @@ -34,6 +34,7 @@ import { } from "./db"; import { fetchDerivedData, + getIndexedDerivedDataFiles, putDerivedData, type RawRemoteDerivedData, type RemoteDerivedData, @@ -64,15 +65,16 @@ interface IndexableItem { * * ext. event state then state * ------------- --------------- -------------- + * sync -> "pull" -> "idle" * sync -> "backfillq" -> "idle" * upload -> "liveq" -> "idle" * idleTimeout -> "backfillq" -> "idle" * * where: * + * - "pull": pull existing embeddings from remote. * - "liveq": indexing items that are being uploaded, - * - "backfillq": fetching remote embeddings of unindexed items, and then - * indexing them if needed, + * - "backfillq": index unindexed items otherwise. * - "idle": in between state transitions. 
* * In addition, MLWorker can also be invoked for interactive tasks: in @@ -81,7 +83,9 @@ interface IndexableItem { export class MLWorker { private electron: ElectronMLWorker | undefined; private delegate: MLWorkerDelegate | undefined; - private state: "idle" | "indexing" = "idle"; + private state: "idle" | "waking" | "pull" | "indexing" = "idle"; + private shouldPull = false; + private havePulledAtLeastOnce = false; private liveQ: IndexableItem[] = []; private idleTimeout: ReturnType<typeof setTimeout> | undefined; private idleDuration = idleDurationStart; /* unit: seconds */ @@ -127,23 +131,36 @@ export class MLWorker { } /** - * Start backfilling if needed. + * Pull embeddings from remote, and start backfilling if needed. * - * This function enqueues a backfill attempt and returns immediately without - * waiting for it complete. During a backfill, it will first attempt to - * fetch embeddings for files which don't have that data locally. If we - * fetch and find what we need, we save it locally. Otherwise we index them. + * This function enqueues a pull and returns immediately without waiting for + * the pull to complete. + * + * Once the pull is done, it then schedules a backfill. So calling this also + * implicitly triggers a backfill (which is why we call it a less-precise + * "sync" instead of "pull"). + * + * During a backfill we will first attempt to fetch embeddings for files + * which don't have that data locally. If we fetch and find what we need, we + * save it locally. Otherwise we index them. Thus the pull upfront is not + * necessary, but it helps a new client get up to speed faster since it can + * fetch all existing embeddings first before getting down to the indexing. */ sync() { + this.shouldPull = true; this.wakeUp(); } /** Invoked in response to external events. */ private wakeUp() { if (this.state == "idle") { - // Currently paused. Get back to work. + // We are currently paused. Get back to work. 
if (this.idleTimeout) clearTimeout(this.idleTimeout); this.idleTimeout = undefined; + // Change state so that multiple calls to `wakeUp` don't cause + // multiple calls to `tick`. + this.state = "waking"; + // Enqueue a tick. void this.tick(); } else { // In the middle of a task. Do nothing, `this.tick` will @@ -200,18 +217,49 @@ export class MLWorker { { state: this.state, liveQ: this.liveQ, + shouldPull: this.shouldPull, idleDuration: this.idleDuration, }, ]); const scheduleTick = () => void setTimeout(() => this.tick(), 0); + // If we've been asked to pull, do that first (before indexing). + if (this.shouldPull) { + // Allow this flag to be reset while we're pulling (triggering + // another pull when we tick next). + this.shouldPull = false; + this.state = "pull"; + try { + const didPull = await pull(); + // Mark that we completed one attempt at pulling successfully + // (irrespective of whether or not that got us some data). + this.havePulledAtLeastOnce = true; + // Reset the idle duration if we did pull something. + if (didPull) this.idleDuration = idleDurationStart; + } catch (e) { + log.error("Failed to pull embeddings", e); + } + // Tick again, even if we got an error. + // + // While the backfillQ won't be processed until at least a pull has + // happened once (`havePulledAtLeastOnce`), the liveQ can still be + // processed since these are new files without remote embeddings. + scheduleTick(); + return; + } + const liveQ = this.liveQ; this.liveQ = []; this.state = "indexing"; - // Use the liveQ if present, otherwise get the next batch to backfill. - const items = liveQ.length > 0 ? liveQ : await this.backfillQ(); + // Use the liveQ if present, otherwise get the next batch to backfill, + // but only after we've pulled once from remote successfully. + const items = liveQ.length + ? liveQ + : this.havePulledAtLeastOnce + ? 
await this.backfillQ() + : []; const allSuccess = await indexNextBatch( items, @@ -263,6 +311,137 @@ expose(MLWorker); +/** + * Pull embeddings from remote. + * + * Return true if at least one embedding was pulled. + */ +const pull = async () => { + // If we've never pulled before, start at the beginning (0). + const sinceTime = (await latestDerivedDataUpdatedAt()) ?? 0; + // Start fetching, starting the fetched count at 0. + const fetchedCount = await pullSince(sinceTime, 0); + // Return true if something got fetched. + return fetchedCount > 0; +}; + +const latestDerivedDataUpdatedAt = () => getKVN("latestDerivedDataUpdatedAt"); + +const setLatestDerivedDataUpdatedAt = (n: number) => + setKV("latestDerivedDataUpdatedAt", n); + +const pullSince = async (sinceTime: number, fetchedCount: number) => { + // See if anything has changed since `sinceTime`. + const indexedFiles = await getIndexedDerivedDataFiles(sinceTime, 200); + + // Nothing more is left. Return the previous fetch count we got. + if (!indexedFiles.length) return fetchedCount; + + // Find the latest from amongst all the updatedAt we got back. This'll serve + // as our checkpoint for the next pull. + const latestUpdatedAt = indexedFiles.reduce( + (max, { updatedAt }) => Math.max(max, updatedAt), + sinceTime, + ); + + // Fetch the embeddings for the files which changed. + // + // In rare cases, remote might return a partial response, but that will not + // have any lasting impact since we anyways refetch the derived data before + // attempting indexing. 
+ const localFiles = await getAllLocalFiles(); + const localFilesByID = new Map(localFiles.map((f) => [f.id, f])); + + const filesByID = new Map( + indexedFiles + .map(({ fileID }) => localFilesByID.get(fileID)) + .filter((x) => x !== undefined) + .map((f) => [f.id, f]), + ); + + const items = await fetchDerivedData(filesByID); + + const save = async ([id, data]: [number, RemoteDerivedData]) => { + try { + await saveDerivedData(id, data); + } catch (e) { + // Ignore errors during saving individual items, let the rest of the + // pull proceed. Failures will not have a lasting impact since the + // file will anyways get revisited as part of a backfill. + log.warn( + `Ignoring error when saving pulled derived data for file id ${id}`, + e, + ); + } + }; + + // Save items. + await Promise.all([...items.entries()].map(save)); + + // Save the checkpoint. + await setLatestDerivedDataUpdatedAt(latestUpdatedAt); + + // Fetch subsequent items. As a safety valve, ensure we don't get into an + // infinite loop by checking that the sinceTime has advanced. + + if (latestUpdatedAt == sinceTime) + throw new Error(`Since time ${sinceTime} did not advance after a pull`); + + return pullSince(latestUpdatedAt, fetchedCount + items.size); +}; + +/** + * Save the given {@link remoteDerivedData} for {@link fileID}. + * + * This is a subset of the save sequence during {@link index}. This one is meant + * to be used during a {@link pull}. + */ +const saveDerivedData = async ( + fileID: number, + remoteDerivedData: RemoteDerivedData, +) => { + // Discard any existing data that is made by an older indexing pipeline. 
+ // See: [Note: Embedding versions] + + const existingRemoteFaceIndex = remoteDerivedData.parsed?.face; + const existingRemoteCLIPIndex = remoteDerivedData.parsed?.clip; + + let existingFaceIndex: FaceIndex | undefined; + if ( + existingRemoteFaceIndex && + existingRemoteFaceIndex.version >= faceIndexingVersion + ) { + const { width, height, faces } = existingRemoteFaceIndex; + existingFaceIndex = { width, height, faces }; + } + + let existingCLIPIndex: CLIPIndex | undefined; + if ( + existingRemoteCLIPIndex && + existingRemoteCLIPIndex.version >= clipIndexingVersion + ) { + const { embedding } = existingRemoteCLIPIndex; + existingCLIPIndex = { embedding }; + } + + // If we have all the required embedding types, then save them, marking a + // file as indexed. + // + // In particular, this means that there might be files which we've marked + // indexed but still don't have the optional derived data types like exif. + // This is fine, we wish to compute the optional type of derived data when + // we can, but by themselves they're not reason enough for us to download + // and index the original. + + if (existingFaceIndex && existingCLIPIndex) { + await saveIndexes( + { fileID, ...existingFaceIndex }, + { fileID, ...existingCLIPIndex }, + ); + } +}; + /** * Find out files which need to be indexed. Then index the next batch of them. * @@ -321,6 +500,9 @@ const indexNextBatch = async ( await wait(0); } + // Wait for the pending tasks to drain out. + await Promise.all(tasks); + // Return true if nothing failed. return allSuccess; }; @@ -436,6 +618,8 @@ const index = async ( existingRemoteFaceIndex && existingRemoteFaceIndex.version >= faceIndexingVersion ) { + // Destructure the data we got from remote so that we only retain the + // fields we're interested in the object that gets put into indexed db. const { width, height, faces } = existingRemoteFaceIndex; existingFaceIndex = { width, height, faces }; }