[desktop] Pull existing embeddings first before starting on indexing (#2591)

Speeds up the initial sync on a new client.
Manav Rathi
2024-08-01 12:36:14 +05:30
committed by GitHub
2 changed files with 244 additions and 13 deletions

View File

@@ -235,7 +235,7 @@ const remoteDerivedDataFromJSONString = (jsonString: string) => {
* @param fileIDs The ids of the files for which we want the embeddings.
*
* @returns a list of {@link RemoteEmbedding} for the files which had embeddings
* (and thatt remote was able to successfully retrieve). The order of this list
* (and that remote was able to successfully retrieve). The order of this list
* is arbitrary, and the caller should use the {@link fileID} present within the
* {@link RemoteEmbedding} to associate an item in the result back to a file
* instead of relying on the order or count of items in the result.
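Since the order is arbitrary, a caller would key the result by fileID. A minimal sketch (fetchEmbeddings here is a placeholder name for the function this comment documents):

// Sketch (assumed names): index the response by fileID instead of
// relying on its order or count.
const remoteEmbeddings = await fetchEmbeddings(fileIDs);
const byFileID = new Map(remoteEmbeddings.map((e) => [e.fileID, e]));
for (const fileID of fileIDs) {
    const embedding = byFileID.get(fileID);
    if (!embedding) continue; // remote didn't have it, or couldn't retrieve it
    // ... use embedding here ...
}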
@@ -308,3 +308,50 @@ const putEmbedding = async (
});
ensureOk(res);
};
/** A single entry in the response of {@link getIndexedFiles}. */
const IndexedFile = z.object({
fileID: z.number(),
updatedAt: z.number(),
});
type IndexedFile = z.infer<typeof IndexedFile>;
/**
* Fetch the file ids whose {@link model} derived data has been created or
* updated since the given {@link sinceTime}.
*
* This allows a client to perform a quick "diff" and first fetch all derived
* data that has changed since it last checked. By fetching all of these
* upfront instead of one by one during indexing, we can speed up the
* initial sync of existing embeddings on a new client.
*
* @param sinceTime Epoch milliseconds. We use this to ask remote to only
* provide derived data whose {@link updatedAt} is greater than the given
* value. If not specified, then we'll start from the beginning.
*
* @param limit An advisory limit on the number of items to return.
*
* @returns an array of file ids, each with an associated timestamp when the
* derived data for that file was last changed.
*
* The caller should persist the latest amongst these timestamps and use it in
* subsequent calls to resume pulling from the current checkpoint.
*/
export const getIndexedDerivedDataFiles = async (
sinceTime: number,
limit: number,
): Promise<IndexedFile[]> => {
const params = new URLSearchParams({
model: "derived",
sinceTime: sinceTime.toString(),
limit: limit.toString(),
});
const url = await apiURL("/embeddings/indexed-files");
const res = await fetch(`${url}?${params.toString()}`, {
headers: await authenticatedRequestHeaders(),
});
ensureOk(res);
return z.object({ diff: z.array(IndexedFile) }).parse(await res.json())
.diff;
};
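For illustration, a caller might page through this diff while persisting a checkpoint (a sketch only; the worker's pull / pullSince below is the real consumer):

// Sketch: resume from a persisted checkpoint, advancing it per page.
let sinceTime = 0; // or a previously persisted checkpoint
for (;;) {
    const page = await getIndexedDerivedDataFiles(sinceTime, 200);
    if (page.length == 0) break;
    // ... fetch / refresh derived data for the fileIDs in this page ...
    for (const { updatedAt } of page)
        sinceTime = Math.max(sinceTime, updatedAt);
    // persist sinceTime here so a future run can resume from this point
}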

View File

@@ -1,6 +1,6 @@
import { clientPackageName } from "@/base/app";
import { isHTTP4xxError } from "@/base/http";
import { getKVN } from "@/base/kv";
import { getKVN, setKV } from "@/base/kv";
import { ensureAuthToken } from "@/base/local-user";
import log from "@/base/log";
import type { ElectronMLWorker } from "@/base/types/ipc";
@@ -34,6 +34,7 @@ import {
} from "./db";
import {
fetchDerivedData,
getIndexedDerivedDataFiles,
putDerivedData,
type RawRemoteDerivedData,
type RemoteDerivedData,
@@ -64,15 +65,16 @@ interface IndexableItem {
*
* ext. event     state then       state
* -------------  ---------------  --------------
* sync           -> "pull"        -> "idle"
* sync           -> "backfillq"   -> "idle"
* upload         -> "liveq"       -> "idle"
* idleTimeout    -> "backfillq"   -> "idle"
*
* where:
*
* - "pull": pull existing embeddings from remote.
* - "liveq": indexing items that are being uploaded,
* - "backfillq": fetching remote embeddings of unindexed items, and then
* indexing them if needed,
* - "backfillq": index unindexed items otherwise.
* - "idle": in between state transitions.
*
* In addition, MLWorker can also be invoked for interactive tasks: in
@@ -81,7 +83,9 @@ interface IndexableItem {
export class MLWorker {
private electron: ElectronMLWorker | undefined;
private delegate: MLWorkerDelegate | undefined;
private state: "idle" | "indexing" = "idle";
private state: "idle" | "waking" | "pull" | "indexing" = "idle";
private shouldPull = false;
private havePulledAtLeastOnce = false;
private liveQ: IndexableItem[] = [];
private idleTimeout: ReturnType<typeof setTimeout> | undefined;
private idleDuration = idleDurationStart; /* unit: seconds */
@@ -127,23 +131,36 @@ export class MLWorker {
}
/**
* Start backfilling if needed.
* Pull embeddings from remote, and start backfilling if needed.
*
* This function enqueues a backfill attempt and returns immediately without
* waiting for it complete. During a backfill, it will first attempt to
* fetch embeddings for files which don't have that data locally. If we
* fetch and find what we need, we save it locally. Otherwise we index them.
* This function enqueues a pull and returns immediately without waiting for
* the pull to complete.
*
* Once the pull is done, it then schedules a backfill. So calling this also
* implicitly triggers a backfill (which is why we call it a less-precise
* "sync" instead of "pull").
*
* During a backfill we will first attempt to fetch embeddings for files
* which don't have that data locally. If we fetch and find what we need, we
* save it locally. Otherwise we index them. Thus the pull upfront is not
* necessary, but it helps a new client get up to speed faster since it can
* fetch all existing embeddings first before getting down to the indexing.
*/
sync() {
this.shouldPull = true;
this.wakeUp();
}
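As a usage sketch (the call site is an assumption, not part of this diff), the app layer would invoke this whenever remote state may have changed:

// Hypothetical wiring: after a remote files sync completes, enqueue a
// pull (and the backfill that follows it). sync() returns immediately.
const onRemoteFilesSynced = (worker: MLWorker) => worker.sync();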
/** Invoked in response to external events. */
private wakeUp() {
if (this.state == "idle") {
// Currently paused. Get back to work.
// We are currently paused. Get back to work.
if (this.idleTimeout) clearTimeout(this.idleTimeout);
this.idleTimeout = undefined;
// Change state so that multiple calls to `wakeUp` don't cause
// multiple calls to `tick`.
this.state = "waking";
// Enqueue a tick.
void this.tick();
} else {
// In the middle of a task. Do nothing, `this.tick` will
@@ -200,18 +217,49 @@ export class MLWorker {
{
state: this.state,
liveQ: this.liveQ,
shouldPull: this.shouldPull,
idleDuration: this.idleDuration,
},
]);
const scheduleTick = () => void setTimeout(() => this.tick(), 0);
// If we've been asked to pull, do that first (before indexing).
if (this.shouldPull) {
// Allow this flag to be reset while we're pulling (triggering
// another pull when we tick next).
this.shouldPull = false;
this.state = "pull";
try {
const didPull = await pull();
// Mark that we completed one attempt at pulling successfully
// (irrespective of whether or not that got us some data).
this.havePulledAtLeastOnce = true;
// Reset the idle duration if we did pull something.
if (didPull) this.idleDuration = idleDurationStart;
} catch (e) {
log.error("Failed to pull embeddings", e);
}
// Tick again, even if we got an error.
//
// While the backfillQ won't be processed until at least a pull has
// happened once (`havePulledAtLeastOnce`), the liveQ can still be
// processed since these are new files without remote embeddings.
scheduleTick();
return;
}
const liveQ = this.liveQ;
this.liveQ = [];
this.state = "indexing";
// Use the liveQ if present, otherwise get the next batch to backfill.
const items = liveQ.length > 0 ? liveQ : await this.backfillQ();
// Use the liveQ if present, otherwise get the next batch to backfill,
// but only after we've pulled once from remote successfully.
const items = liveQ.length
? liveQ
: this.havePulledAtLeastOnce
? await this.backfillQ()
: [];
const allSuccess = await indexNextBatch(
items,
@@ -263,6 +311,137 @@ export class MLWorker {
expose(MLWorker);
/**
* Pull embeddings from remote.
*
* Return true if at least one embedding was pulled.
*/
const pull = async () => {
// If we've never pulled before, start at the beginning (0).
const sinceTime = (await latestDerivedDataUpdatedAt()) ?? 0;
// Start fetching, starting the fetched count at 0.
const fetchedCount = await pullSince(sinceTime, 0);
// Return true if something got fetched.
return fetchedCount > 0;
};
const latestDerivedDataUpdatedAt = () => getKVN("latestDerivedDataUpdatedAt");
const setLatestDerivedDataUpdatedAt = (n: number) =>
setKV("latestDerivedDataUpdatedAt", n);
const pullSince = async (sinceTime: number, fetchedCount: number) => {
// See if anything has changed since `sinceTime`.
const indexedFiles = await getIndexedDerivedDataFiles(sinceTime, 200);
// Nothing more is left. Return the previous fetch count we got.
if (!indexedFiles.length) return fetchedCount;
// Find the latest amongst all the updatedAt values we got back. This'll serve
// as our checkpoint for the next pull.
const latestUpdatedAt = indexedFiles.reduce(
(max, { updatedAt }) => Math.max(max, updatedAt),
sinceTime,
);
// Fetch the embeddings for the files which changed.
//
// In rare cases, remote might return a partial response, but that will not
// have any lasting impact since we anyways refetch the derived data before
// attempting indexing.
const localFiles = await getAllLocalFiles();
const localFilesByID = new Map(localFiles.map((f) => [f.id, f]));
const filesByID = new Map(
indexedFiles
.map(({ fileID }) => localFilesByID.get(fileID))
.filter((x) => x !== undefined)
.map((f) => [f.id, f]),
);
const items = await fetchDerivedData(filesByID);
const save = async ([id, data]: [number, RemoteDerivedData]) => {
try {
await saveDerivedData(id, data);
} catch (e) {
// Ignore errors during saving individual items, let the rest of the
// pull proceed. Failures will not have a lasting impact since the
// file will anyways get revisited as part of a backfill.
log.warn(
`Ignoring error when saving pulled derived data for file id ${id}`,
e,
);
}
};
// Save items.
await Promise.all([...items.entries()].map(save));
// Save the checkpoint.
await setLatestDerivedDataUpdatedAt(latestUpdatedAt);
// Fetch subsequent items. As a safety valve, ensure we don't get into an
// infinite loop by checking that the sinceTime has advanced.
if (latestUpdatedAt == sinceTime)
throw new Error(`Since time ${sinceTime} did not advance after a pull`);
return pullSince(latestUpdatedAt, fetchedCount + items.size);
};
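pullSince recurses once per page of up to 200 entries. An equivalent iterative formulation (a sketch, not part of this commit) makes the loop and the safety valve explicit:

// Sketch: an iterative equivalent of pull + pullSince.
const pullIteratively = async (): Promise<boolean> => {
    let sinceTime = (await latestDerivedDataUpdatedAt()) ?? 0;
    let fetchedCount = 0;
    for (;;) {
        const indexedFiles = await getIndexedDerivedDataFiles(sinceTime, 200);
        if (!indexedFiles.length) return fetchedCount > 0;
        const latestUpdatedAt = indexedFiles.reduce(
            (max, { updatedAt }) => Math.max(max, updatedAt),
            sinceTime,
        );
        // ... fetch and save the derived data for this page, as above ...
        await setLatestDerivedDataUpdatedAt(latestUpdatedAt);
        if (latestUpdatedAt == sinceTime)
            throw new Error(`Since time ${sinceTime} did not advance`);
        sinceTime = latestUpdatedAt;
        fetchedCount += indexedFiles.length; // page size approximates items fetched
    }
};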
/**
* Save the given {@link remoteDerivedData} for {@link fileID}.
*
* This is a subset of the save sequence during {@link index}. This one is meant
* to be used during a {@link pull}.
*/
const saveDerivedData = async (
fileID: number,
remoteDerivedData: RemoteDerivedData,
) => {
// Discard any existing data that was made by an older indexing pipeline.
// See: [Note: Embedding versions]
const existingRemoteFaceIndex = remoteDerivedData.parsed?.face;
const existingRemoteCLIPIndex = remoteDerivedData.parsed?.clip;
let existingFaceIndex: FaceIndex | undefined;
if (
existingRemoteFaceIndex &&
existingRemoteFaceIndex.version >= faceIndexingVersion
) {
const { width, height, faces } = existingRemoteFaceIndex;
existingFaceIndex = { width, height, faces };
}
let existingCLIPIndex: CLIPIndex | undefined;
if (
existingRemoteCLIPIndex &&
existingRemoteCLIPIndex.version >= clipIndexingVersion
) {
const { embedding } = existingRemoteCLIPIndex;
existingCLIPIndex = { embedding };
}
// If we have all the required embedding types, then save them, marking a
// file as indexed.
//
// In particular, this means that there might be files which we've marked
// indexed but still don't have the optional derived data types like exif.
// This is fine: we wish to compute the optional types of derived data
// when we can, but by themselves they're not reason enough for us to
// download and index the original.
if (existingFaceIndex && existingCLIPIndex) {
await saveIndexes(
{ fileID, ...existingFaceIndex },
{ fileID, ...existingCLIPIndex },
);
}
};
/**
* Find out files which need to be indexed. Then index the next batch of them.
*
@@ -321,6 +500,9 @@ const indexNextBatch = async (
await wait(0);
}
// Wait for the pending tasks to drain out.
await Promise.all(tasks);
// Return true if nothing failed.
return allSuccess;
};
@@ -436,6 +618,8 @@ const index = async (
existingRemoteFaceIndex &&
existingRemoteFaceIndex.version >= faceIndexingVersion
) {
// Destructure the data we got from remote so that the object that gets
// put into the indexed db only retains the fields we're interested in.
const { width, height, faces } = existingRemoteFaceIndex;
existingFaceIndex = { width, height, faces };
}