This commit is contained in:
Manav Rathi
2024-07-09 13:14:24 +05:30
parent 023610cebd
commit ca44202fcd

View File

@@ -12,8 +12,9 @@ import { getKV, setKV } from "@/next/kv";
import log from "@/next/log";
import { apiURL } from "@/next/origins";
import { z } from "zod";
import { saveFaceIndex } from "./db";
import { type FaceIndex, faceIndexingVersion } from "./face";
import { clipIndexingVersion, type CLIPIndex } from "./clip";
import { saveCLIPIndex, saveFaceIndex } from "./db";
import { faceIndexingVersion, type FaceIndex } from "./face";
/**
* The embeddings that we (the current client) knows how to handle.
@@ -397,3 +398,61 @@ const FaceIndex = z
*/
export const putFaceIndex = async (enteFile: EnteFile, faceIndex: FaceIndex) => {
    // putEmbedding is handed the JSON serialization of the index, not the
    // object itself.
    const serializedIndex = JSON.stringify(faceIndex);
    return putEmbedding(enteFile, "file-ml-clip-face", serializedIndex);
};
// MARK: - CLIP
/**
 * Fetch new or updated CLIP embeddings from remote and save them locally.
 *
 * This delegates to {@link pullEmbeddings} for the "onnx-clip" model, handing
 * it a callback that is given the JSON string of an embedding. The callback
 * parses that string, validates the result against the {@link CLIPIndex} Zod
 * schema (which throws if the shape does not match), and then passes the
 * parsed index to {@link saveCLIPIndexIfNewer}.
 *
 * See {@link pullFaceEmbeddings} for a sibling function with more comprehensive
 * documentation.
 *
 * @returns true if at least one embedding was pulled, false otherwise.
 */
export const pullCLIPEmbeddings = () =>
pullEmbeddings("onnx-clip", (jsonString: string) =>
// NOTE: The two directive comments below must stay directly above the
// expression they guard; do not insert code between them and the call.
// eslint-disable-next-line @typescript-eslint/prefer-ts-expect-error, @typescript-eslint/ban-ts-comment
// @ts-ignore TODO: There is no error here, but this file is imported by
// one of our packages that doesn't have strict mode enabled yet,
// causing a spurious error to be emitted in that context.
saveCLIPIndexIfNewer(CLIPIndex.parse(JSON.parse(jsonString))),
);
/**
 * Persist the given CLIP {@link index} locally, unless it was produced by an
 * indexer older than ours.
 *
 * This wraps {@link saveCLIPIndex} with the version check described in
 * [Note: Handling versioning of embeddings]: an index whose version is lower
 * than {@link clipIndexingVersion} is logged and dropped; anything else is
 * saved.
 *
 * @param index The CLIP index obtained from remote.
 *
 * @returns The result of {@link saveCLIPIndex} when the index is saved, and
 * nothing when it is ignored.
 */
const saveCLIPIndexIfNewer = async (index: CLIPIndex) => {
    const { version } = index;

    // Phrased as the negation of the "older than" test so that the save path
    // is taken in exactly the cases where the comparison is false.
    const isOlderThanOurs = version < clipIndexingVersion;
    if (!isOlderThanOurs) return saveCLIPIndex(index);

    log.info(
        `Ignoring remote CLIP index with version ${version} older than what our indexer can produce (${clipIndexingVersion})`,
    );
    return;
};
/**
 * Zod schema for the {@link CLIPIndex} type.
 *
 * This is the runtime (value level) counterpart of the TypeScript type of the
 * same name that is imported from "./clip". It is used to validate the parsed
 * JSON of a remote embedding before it is saved locally.
 *
 * See: [Note: Duplicated between Zod schemas and TS type]
 */
const CLIPIndex = z
.object({
// All four fields are required; parsing fails if any of them is missing
// or has a different type.
fileID: z.number(),
// Compared against `clipIndexingVersion` before the index is saved.
version: z.number(),
client: z.string(),
embedding: z.array(z.number()),
})
// Retain fields we might not (currently) understand.
.passthrough();
/**
 * Upload the CLIP index of the given {@link enteFile} to remote, so that other
 * clients can directly pull it instead of needing to reindex.
 *
 * @param enteFile The file that the index belongs to.
 *
 * @param clipIndex The CLIP index to upload. It is serialized to JSON and
 * passed to {@link putEmbedding} under the "onnx-clip" model name.
 */
export const putCLIPIndex = async (enteFile: EnteFile, clipIndex: CLIPIndex) => {
    const serializedIndex = JSON.stringify(clipIndex);
    return putEmbedding(enteFile, "onnx-clip", serializedIndex);
};