diff --git a/web/packages/new/photos/services/ml/cluster.ts b/web/packages/new/photos/services/ml/cluster.ts index a290f541da..8ea107e78f 100644 --- a/web/packages/new/photos/services/ml/cluster.ts +++ b/web/packages/new/photos/services/ml/cluster.ts @@ -3,6 +3,8 @@ import log from "@/base/log"; import { ensure } from "@/utils/ensure"; import { wait } from "@/utils/promise"; import type { EnteFile } from "../../types/file"; +import { savedCGroups } from "../user-entity"; +import { savedFaceClusters } from "./db"; import { faceDirection, fileIDFromFaceID, @@ -31,17 +33,6 @@ export interface FaceCluster { faces: string[]; } -const clusteringOptions = { - minBlur: 10, - minScore: 0.8, - minClusterSize: 2, - joinThreshold: 0.76, - earlyExitThreshold: 0.9, - batchSize: 10000, - offsetIncrement: 7500, - badFaceHeuristics: true, -}; - export interface ClusteringProgress { completed: number; total: number; @@ -86,22 +77,12 @@ export const clusterFaces = async ( localFiles: EnteFile[], onProgress: (progress: ClusteringProgress) => void, ) => { - const { - minBlur, - minScore, - minClusterSize, - joinThreshold, - earlyExitThreshold, - batchSize, - offsetIncrement, - badFaceHeuristics, - } = clusteringOptions; const t = Date.now(); const localFileByID = new Map(localFiles.map((f) => [f.id, f])); // A flattened array of filtered and annotated faces. - const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)]; + const filteredFaces = [...enumerateFaces(faceIndexes)]; const fileForFaceID = new Map( filteredFaces.map(({ faceID }) => [ @@ -121,64 +102,69 @@ export const clusterFaces = async ( fileForFace(a).metadata.creationTime, ); - // For fast reverse lookup - map from face ids to the face. - const faceForFaceID = new Map(faces.map((f) => [f.faceID, f])); - // The existing, and resultant, clusters. // // We start with clusters we have currently (whether from remote or from a - // previous local run). 
This function is called after we have already synced - // cgroups with remote, so we know that it is up to date. - // - // We generate two sets of maps. - // - // Remote: - // - faceID -> clusterID - // - clusterID -> cgroupID - // - // Local: - // - faceID -> clusterID - // - // Then merge them into a single set of maps, giving preference to the - // information that we get from remote. - // - // We then run clustering using these maps (skipping over faces that have - // already been clustered). + // previous local run, but preferring the remote ones). This function is + // called after we have already synced cgroups with remote, so we know that + // it is up to date. - // TODO-Cluster Later on, instead of starting from a blank slate, this will - // be list of existing clusters we fetch from remote. - // - fetchRemoteClusterFeedback - // [..local, ..remote] let clusters: FaceCluster[] = []; + // Get the locally available remote cluster groups. + const cgroups = await savedCGroups(); + // Sort them so that the latest ones are first. + const sortedCGroups = cgroups.sort((a, b) => b.updatedAt - a.updatedAt); + // Extract the remote clusters. + clusters = clusters.concat( + sortedCGroups.map((cg) => cg.data.assigned).flat(), + ); + + // Add on the clusters we have available locally. + clusters = clusters.concat(await savedFaceClusters()); + // For fast reverse lookup - map from the id of a face to the id of the // cluster to which it belongs. - let clusterIDForFaceID = new Map(); + const faceIDToClusterID = new Map(); - // For fast reverse lookup - map from the id of a cluster to its position in - // the clusters array. - let clusterIndexForFaceID = new Map(); + // For fast reverse lookup - map from the id of a face to the index of its + // corresponding cluster in the clusters array. + const faceIDToClusterIndex = new Map(); + + // Fill in the maps with the existing data. 
Since the remote clusters come + // first, they'll be preferred over any existing local clusters for the same + // face (as is the desired behaviour). + + for (const [i, cluster] of clusters.entries()) { + for (const faceID of cluster.faces) { + if (!faceIDToClusterID.has(faceID)) { + faceIDToClusterID.set(faceID, cluster.id); + faceIDToClusterIndex.set(faceID, i); + } + } + } + + const state = { + faceIDToClusterID, + faceIDToClusterIndex, + clusters, + }; // Process the faces in batches, but keep an overlap between batches to // allow "links" to form with existing clusters. + + const batchSize = 10000; + const offsetIncrement = 7500; + for (let offset = 0; offset < faces.length; offset += offsetIncrement) { const it = Date.now(); const batch = faces.slice(offset, offset + batchSize); log.info(`[batch] processing ${offset} to ${offset + batch.length}`); - const oldState = { - clusterIDForFaceID, - clusterIndexForFaceID, - clusters, - }; - - const newState = await clusterBatchLinear( + await clusterBatchLinear( batch, - oldState, - joinThreshold, - earlyExitThreshold, - badFaceHeuristics, + state, ({ completed }: ClusteringProgress) => onProgress({ completed: offset + completed, @@ -186,19 +172,13 @@ export const clusterFaces = async ( }), ); - clusterIDForFaceID = newState.clusterIDForFaceID; - clusterIndexForFaceID = newState.clusterIndexForFaceID; - clusters = newState.clusters; - log.info( - `[batch] ${newState.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`, + `[batch] ${state.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`, ); } // Prune clusters that are smaller than the threshold. 
- const validClusters = clusters.filter( - (cs) => cs.faces.length > minClusterSize, - ); + const validClusters = clusters.filter((cs) => cs.faces.length > 2); const sortedClusters = validClusters.sort( (a, b) => b.faces.length - a.faces.length, @@ -207,9 +187,13 @@ export const clusterFaces = async ( // TODO-Cluster // This isn't really part of the clustering, but help the main thread out by // pre-computing temporary in-memory people, one per cluster. + + // For fast reverse lookup - map from face ids to the face. + const faceForFaceID = new Map(faces.map((f) => [f.faceID, f])); + const people = toPeople(sortedClusters, localFileByID, faceForFaceID); - const clusteredFaceCount = clusterIDForFaceID.size; + const clusteredFaceCount = faceIDToClusterID.size; const timeTakenMs = Date.now() - t; log.info( `Generated ${sortedClusters.length} clusters from ${faces.length} faces (${clusteredFaceCount} clustered ${faces.length - clusteredFaceCount} unclustered) (${timeTakenMs} ms)`, @@ -230,30 +214,20 @@ export const clusterFaces = async ( * dot product calculations that will happen during clustering and attaches * other information that the clustering algorithm needs. */ -function* enumerateFaces( - faceIndices: FaceIndex[], - minBlur: number, - minScore: number, -) { +function* enumerateFaces(faceIndices: FaceIndex[]) { for (const fi of faceIndices) { - for (const f of fi.faces) { - if (shouldIncludeFace(f, minBlur, minScore)) { + for (const face of fi.faces) { + if (face.blur > 10 && face.score > 0.8) { yield { - ...f, - embedding: new Float32Array(f.embedding), - isBadFace: isBadFace(f), + ...face, + embedding: new Float32Array(face.embedding), + isBadFace: isBadFace(face), }; } } } } -/** - * Return true if the given face is above the minimum inclusion thresholds. 
- */ -const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) => - face.blur > minBlur && face.score > minScore; - /** * Return true if the given face is above the minimum inclusion thresholds, but * is otherwise heuristically determined to be possibly spurious face detection. @@ -272,29 +246,20 @@ const isSidewaysFace = (face: Face) => const newClusterID = () => newNonSecureID("cluster_"); interface ClusteringState { - clusterIDForFaceID: Map; - clusterIndexForFaceID: Map; + faceIDToClusterID: Map; + faceIDToClusterIndex: Map; clusters: FaceCluster[]; } const clusterBatchLinear = async ( faces: ClusterFace[], - oldState: ClusteringState, - joinThreshold: number, - earlyExitThreshold: number, - badFaceHeuristics: boolean, + state: ClusteringState, onProgress: (progress: ClusteringProgress) => void, ) => { - const state: ClusteringState = { - clusterIDForFaceID: new Map(oldState.clusterIDForFaceID), - clusterIndexForFaceID: new Map(oldState.clusterIndexForFaceID), - clusters: [...oldState.clusters], - }; - // Sort the faces so that the already clustered ones are at the front. faces = faces - .filter((f) => state.clusterIDForFaceID.has(f.faceID)) - .concat(faces.filter((f) => !state.clusterIDForFaceID.has(f.faceID))); + .filter((f) => state.faceIDToClusterID.has(f.faceID)) + .concat(faces.filter((f) => !state.faceIDToClusterID.has(f.faceID))); // For each face in the batch for (const [i, fi] of faces.entries()) { @@ -305,7 +270,7 @@ const clusterBatchLinear = async ( } // If the face is already part of a cluster, then skip it. - if (state.clusterIDForFaceID.has(fi.faceID)) continue; + if (state.faceIDToClusterID.has(fi.faceID)) continue; // Find the nearest neighbour among the previous faces in this batch. let nnIndex: number | undefined; @@ -319,15 +284,14 @@ const clusterBatchLinear = async ( // The vectors are already normalized, so we can directly use their // dot product as their cosine similarity. 
const csim = dotProduct(fi.embedding, fj.embedding); - const threshold = - badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold; + const threshold = fj.isBadFace ? 0.84 : 0.76; if (csim > nnCosineSimilarity && csim >= threshold) { nnIndex = j; nnCosineSimilarity = csim; - // If we've found something "near enough", stop looking for a - // better match (A heuristic to speed up clustering). - if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break; + // If we've found something above our early exit threshold, stop + // looking for a better match (a heuristic to speed up clustering). + if (csim >= 0.9) break; } } @@ -335,12 +299,12 @@ const clusterBatchLinear = async ( // Found a neighbour close enough, add ourselves to its cluster. const nnFace = ensure(faces[nnIndex]); const nnClusterIndex = ensure( - state.clusterIndexForFaceID.get(nnFace.faceID), + state.faceIDToClusterIndex.get(nnFace.faceID), ); const nnCluster = ensure(state.clusters[nnClusterIndex]); - state.clusterIDForFaceID.set(fi.faceID, nnCluster.id); - state.clusterIndexForFaceID.set(fi.faceID, nnClusterIndex); + state.faceIDToClusterID.set(fi.faceID, nnCluster.id); + state.faceIDToClusterIndex.set(fi.faceID, nnClusterIndex); nnCluster.faces.push(fi.faceID); } else { // No neighbour within the threshold. Create a new cluster. @@ -348,13 +312,11 @@ const clusterBatchLinear = async ( const clusterIndex = state.clusters.length; const cluster = { id: clusterID, faces: [fi.faceID] }; - state.clusterIDForFaceID.set(fi.faceID, cluster.id); - state.clusterIndexForFaceID.set(fi.faceID, clusterIndex); + state.faceIDToClusterID.set(fi.faceID, cluster.id); + state.faceIDToClusterIndex.set(fi.faceID, clusterIndex); state.clusters.push(cluster); } } - - return state; }; /**