diff --git a/web/packages/new/photos/services/ml/cluster-new.ts b/web/packages/new/photos/services/ml/cluster-new.ts index c412c43f9a..0e4919d517 100644 --- a/web/packages/new/photos/services/ml/cluster-new.ts +++ b/web/packages/new/photos/services/ml/cluster-new.ts @@ -1,6 +1,7 @@ import { newNonSecureID } from "@/base/id-worker"; import log from "@/base/log"; import { ensure } from "@/utils/ensure"; +import { faceClusters } from "./db"; import type { FaceIndex } from "./face"; import { dotProduct } from "./math"; @@ -63,7 +64,7 @@ export interface Person { * * The clusters are generated using locally by clients using this algorithm: * - * 1. clusters = [] + * 1. clusters = [] initially, or fetched from remote. * * 2. For each face, find its nearest neighbour in the embedding space from * amongst the faces that have already been clustered. @@ -82,18 +83,45 @@ export interface Person { * * - They can remove a cluster from a person. */ -export const clusterFaces = (faceIndexes: FaceIndex[]) => { +export const clusterFaces = async (faceIndexes: FaceIndex[]) => { const t = Date.now(); + // The face data that we need (face ID and its embedding). const faces = [...faceIDAndEmbeddings(faceIndexes)]; - let clusters: FaceCluster[] = []; - const clusterIndexByFaceID = new Map(); + // Start with the clusters we already have (either from a previous indexing, + // or fetched from remote). + const clusters = await faceClusters(); + + // For fast reverse lookup - map from cluster ids to the index in the + // clusters array. + const clusterIndexForClusterID = new Map(clusters.map((c, i) => [c.id, i])); + + // For fast reverse lookup - map from face ids to the id of the cluster to + // which they belong. 
+ const clusterIDForFaceID = new Map( + clusters.flatMap((c) => + c.faceIDs.map((faceID) => [faceID, c.id] as const), + ), + ); + + // Generate a new cluster ID + const newClusterID = () => newNonSecureID("cluster_"); + + // For each face for (const [i, { faceID, embedding }] of faces.entries()) { - // Find the nearest neighbour from among the faces we have already seen. + // If the face is already part of a cluster, then skip it. + if (clusterIDForFaceID.get(faceID)) continue; + + // Find the nearest neighbour from among all the other faces. let nnIndex: number | undefined; let nnCosineSimilarity = 0; - for (let j = 0; j < i; j++) { + for (let j = 0; j < faces.length; j++) { + // ! This is an O(n^2) loop, be careful when adding more code here. + + // Skip itself + if (i == j) continue; + // Can't find a way of avoiding the null assertion. // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const n = faces[j]!; @@ -106,36 +134,60 @@ export const clusterFaces = (faceIndexes: FaceIndex[]) => { nnCosineSimilarity = csim; } } - if (nnIndex === undefined) { - // We didn't find a neighbour. Create a new cluster with this face. - const cluster = { - id: newNonSecureID("cluster_"), - faceIDs: [faceID], - }; + if (nnIndex === undefined) { + // We didn't find a neighbour within the threshold. Create a new + // cluster with this face. + // Once pushed, the cluster is the last element (index length - 1). + const cluster = { id: newClusterID(), faceIDs: [faceID] }; clusters.push(cluster); - clusterIndexByFaceID.set(faceID, clusters.length); + clusterIndexForClusterID.set(cluster.id, clusters.length - 1); + clusterIDForFaceID.set(faceID, cluster.id); } else { - // Found a neighbour near enough. Add this face to the neighbour's - // cluster. + // Found a neighbour near enough. 
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion const nn = faces[nnIndex]!; - const nnClusterIndex = ensure(clusterIndexByFaceID.get(nn.faceID)); - clusters[nnClusterIndex]?.faceIDs.push(faceID); - clusterIndexByFaceID.set(faceID, nnClusterIndex); + + // Find the cluster the nearest neighbour belongs to, if any. + const nnClusterID = clusterIDForFaceID.get(nn.faceID); + + if (nnClusterID) { + // If the neighbour is already part of a cluster, also add + // ourselves to that cluster. + + const nnClusterIndex = ensure( + clusterIndexForClusterID.get(nnClusterID), + ); + clusters[nnClusterIndex]?.faceIDs.push(faceID); + clusterIDForFaceID.set(faceID, nnClusterID); + } else { + // Create a new cluster with us and our nearest neighbour. + // Once pushed, the cluster is the last element (index length - 1). + const cluster = { + id: newClusterID(), + faceIDs: [faceID, nn.faceID], + }; + clusters.push(cluster); + clusterIndexForClusterID.set(cluster.id, clusters.length - 1); + clusterIDForFaceID.set(faceID, cluster.id); + clusterIDForFaceID.set(nn.faceID, cluster.id); + } } } - clusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1); + const validClusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1); - log.debug(() => ["ml/cluster", { faces, clusters, clusterIndexByFaceID }]); + log.debug(() => [ + "ml/cluster", + { faces, validClusters, clusterIndexForClusterID, clusterIDForFaceID }, + ]); log.debug( () => - `Clustered ${faces.length} faces into ${clusters.length} clusters (${Date.now() - t} ms)`, + `Clustered ${faces.length} faces into ${validClusters.length} clusters (${Date.now() - t} ms)`, ); - return clusters; + return validClusters; }; /**