diff --git a/web/packages/new/photos/services/ml/clip.ts b/web/packages/new/photos/services/ml/clip.ts index 72de9abedb..773621602f 100644 --- a/web/packages/new/photos/services/ml/clip.ts +++ b/web/packages/new/photos/services/ml/clip.ts @@ -1,6 +1,6 @@ import type { ElectronMLWorker } from "@/base/types/ipc"; import type { ImageBitmapAndData } from "./blob"; -import { getCLIPIndexes } from "./db"; +import { savedCLIPIndexes } from "./db"; import { dotProduct, norm } from "./math"; import type { CLIPMatches } from "./worker-types"; @@ -169,7 +169,7 @@ let _cachedCLIPIndexes: * produces potentially new CLIP indexes). */ const cachedOrReadCLIPIndexes = async () => - (_cachedCLIPIndexes ??= (await getCLIPIndexes()).map( + (_cachedCLIPIndexes ??= (await savedCLIPIndexes()).map( ({ fileID, embedding }) => ({ fileID, embedding: new Float32Array(embedding), diff --git a/web/packages/new/photos/services/ml/db.ts b/web/packages/new/photos/services/ml/db.ts index 3220f6a2ed..4473697aec 100644 --- a/web/packages/new/photos/services/ml/db.ts +++ b/web/packages/new/photos/services/ml/db.ts @@ -206,7 +206,7 @@ const newFileStatus = (fileID: number): FileStatus => ({ /** * Return the {@link FaceIndex}, if any, for {@link fileID}. */ -export const getFaceIndex = async (fileID: number) => { +export const savedFaceIndex = async (fileID: number) => { const db = await mlDB(); return db.get("face-index", fileID); }; @@ -214,7 +214,7 @@ export const getFaceIndex = async (fileID: number) => { /** * Return all face indexes present locally. */ -export const getFaceIndexes = async () => { +export const savedFaceIndexes = async () => { const db = await mlDB(); return await db.getAll("face-index"); }; @@ -222,7 +222,7 @@ export const getFaceIndexes = async () => { /** * Return all CLIP indexes present locally. */ -export const getCLIPIndexes = async () => { +export const savedCLIPIndexes = async () => { const db = await mlDB(); return await db.getAll("clip-index"); }; diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts index 677b340c93..e1af6bcfac 100644 --- a/web/packages/new/photos/services/ml/index.ts +++ b/web/packages/new/photos/services/ml/index.ts @@ -26,7 +26,7 @@ import { import { deleteUserEntity } from "../user-entity/remote"; import type { FaceCluster } from "./cluster"; import { regenerateFaceCrops } from "./crop"; -import { clearMLDB, getFaceIndex, getIndexableAndIndexedCounts } from "./db"; +import { clearMLDB, getIndexableAndIndexedCounts, savedFaceIndex } from "./db"; import { filterNamedPeople, reconstructPeople, type Person } from "./people"; import { MLWorker } from "./worker"; import type { CLIPMatches } from "./worker-types"; @@ -633,7 +633,7 @@ export const getAnnotatedFacesForFile = async ( const annotatedFaceIDs: AnnotatedFaceID[] = []; const otherFaceIDs: string[] = []; - const index = await getFaceIndex(enteFile.id); + const index = await savedFaceIndex(enteFile.id); if (!index) return { annotatedFaceIDs, otherFaceIDs }; const people = _state.peopleSnapshot ?? []; @@ -702,7 +702,7 @@ export const faceCrop = async (faceID: string, enteFile: EnteFile) => { * the file (updating the "face-crops" {@link BlobCache}). */ const regenerateFaceCropsIfNeeded = async (enteFile: EnteFile) => { - const index = await getFaceIndex(enteFile.id); + const index = await savedFaceIndex(enteFile.id); if (!index) return; const cache = await blobCache("face-crops"); diff --git a/web/packages/new/photos/services/ml/people.ts b/web/packages/new/photos/services/ml/people.ts index 08482e0e1a..9a02670dc6 100644 --- a/web/packages/new/photos/services/ml/people.ts +++ b/web/packages/new/photos/services/ml/people.ts @@ -2,8 +2,9 @@ import type { EnteFile } from "../../types/file"; import { getLocalFiles } from "../files"; import { savedCGroups, type CGroup } from "../user-entity"; import type { FaceCluster } from "./cluster"; -import { getFaceIndexes, savedFaceClusters } from "./db"; +import { savedFaceClusters, savedFaceIndexes } from "./db"; import { fileIDFromFaceID } from "./face"; +import { dotProduct } from "./math"; /** * A cgroup ("cluster group") is a group of clusters (possibly containing just a @@ -167,7 +168,7 @@ export const reconstructPeople = async (): Promise => { { faceID: string; file: EnteFile; score: number } >(); - const faceIndexes = await getFaceIndexes(); + const faceIndexes = await savedFaceIndexes(); for (const { faces } of faceIndexes) { for (const { faceID, score } of faces) { const fileID = fileIDFromFaceID(faceID); @@ -305,3 +306,59 @@ export const filterNamedPeople = (people: Person[]): NamedPerson[] => { } return namedPeople; }; + +/** + * Returns suggestions for the given person. + */ +export const suggestionsForPerson = async (person: CGroupPerson) => { + const personClusters = person.cgroup.data.assigned; + // TODO-Cluster: Persist this. + const ignoredClusters: FaceCluster[] = []; + + const clusters = await savedFaceClusters(); + const faceIndexes = await savedFaceIndexes(); + + const embeddingByFaceID = new Map( + faceIndexes + .map(({ faces }) => + faces.map( + (f) => [f.faceID, new Float32Array(f.embedding)] as const, + ), + ) + .flat(), + ); + + const personClusterIDs = new Set(personClusters.map(({ id }) => id)); + const ignoredClusterIDs = new Set(ignoredClusters.map(({ id }) => id)); + + const personFaceEmbeddings = personClusters + .map(({ faces }) => faces.map((id) => embeddingByFaceID.get(id))) + .flat() + .filter((e) => !!e); + + const suggestedClusters: FaceCluster[] = []; + for (const cluster of clusters) { + const { id, faces } = cluster; + + if (personClusterIDs.has(id)) continue; + if (ignoredClusterIDs.has(id)) continue; + + let suggest = false; + for (const fi of faces) { + const ei = embeddingByFaceID.get(fi); + if (!ei) continue; + for (const ej of personFaceEmbeddings) { + const csim = dotProduct(ei, ej); + if (csim >= 0.6) { + suggest = true; + break; + } + } + if (suggest) break; + } + + if (suggest) suggestedClusters.push(cluster); + } + + return suggestedClusters.sort((a, b) => b.faces.length - a.faces.length); +}; diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts index 42ee2a439f..eb8b5f09dd 100644 --- a/web/packages/new/photos/services/ml/worker.ts +++ b/web/packages/new/photos/services/ml/worker.ts @@ -32,9 +32,9 @@ import { } from "./cluster"; import { saveFaceCrops } from "./crop"; import { - getFaceIndexes, getIndexableFileIDs, markIndexingFailed, + savedFaceIndexes, saveIndexes, updateAssumingLocalFiles, } from "./db"; @@ -327,7 +327,7 @@ export class MLWorker { */ async clusterFaces(masterKey: Uint8Array) { const clusters = await clusterFaces( - await getFaceIndexes(), + await savedFaceIndexes(), await getAllLocalFiles(), (progress) => this.updateClusteringProgress(progress), );