Start with existing clusters

This commit is contained in:
Manav Rathi
2024-09-25 16:22:25 +05:30
parent 31ffc5bee5
commit 1ff21b3d8e

View File

@@ -3,6 +3,8 @@ import log from "@/base/log";
import { ensure } from "@/utils/ensure";
import { wait } from "@/utils/promise";
import type { EnteFile } from "../../types/file";
import { savedCGroups } from "../user-entity";
import { savedFaceClusters } from "./db";
import {
faceDirection,
fileIDFromFaceID,
@@ -31,17 +33,6 @@ export interface FaceCluster {
faces: string[];
}
const clusteringOptions = {
minBlur: 10,
minScore: 0.8,
minClusterSize: 2,
joinThreshold: 0.76,
earlyExitThreshold: 0.9,
batchSize: 10000,
offsetIncrement: 7500,
badFaceHeuristics: true,
};
export interface ClusteringProgress {
completed: number;
total: number;
@@ -86,22 +77,12 @@ export const clusterFaces = async (
localFiles: EnteFile[],
onProgress: (progress: ClusteringProgress) => void,
) => {
const {
minBlur,
minScore,
minClusterSize,
joinThreshold,
earlyExitThreshold,
batchSize,
offsetIncrement,
badFaceHeuristics,
} = clusteringOptions;
const t = Date.now();
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
// A flattened array of filtered and annotated faces.
const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)];
const filteredFaces = [...enumerateFaces(faceIndexes)];
const fileForFaceID = new Map(
filteredFaces.map(({ faceID }) => [
@@ -121,64 +102,69 @@ export const clusterFaces = async (
fileForFace(a).metadata.creationTime,
);
// For fast reverse lookup - map from face ids to the face.
const faceForFaceID = new Map(faces.map((f) => [f.faceID, f]));
// The existing, and resultant, clusters.
//
// We start with clusters we have currently (whether from remote or from a
// previous local run). This function is called after we have already synced
// cgroups with remote, so we know that it is up to date.
//
// We generate two sets of maps.
//
// Remote:
// - faceID -> clusterID
// - clusterID -> cgroupID
//
// Local:
// - faceID -> clusterID
//
// Then merge them into a single set of maps, giving preference to the
// information that we get from remote.
//
// We then run clustering using these maps (skipping over faces that have
// already been clustered).
// previous local run, but preferring the remote ones). This function is
// called after we have already synced cgroups with remote, so we know that
// it is up to date.
// TODO-Cluster Later on, instead of starting from a blank slate, this will
// be list of existing clusters we fetch from remote.
// - fetchRemoteClusterFeedback
// [..local, ..remote]
let clusters: FaceCluster[] = [];
// Get the locally available remote cluster groups.
const cgroups = await savedCGroups();
// Sort them so that the latest ones are first.
const sortedCGroups = cgroups.sort((a, b) => b.updatedAt - a.updatedAt);
// Extract the remote clusters.
clusters = clusters.concat(
sortedCGroups.map((cg) => cg.data.assigned).flat(),
);
// Add on the clusters we have available locally.
clusters = clusters.concat(await savedFaceClusters());
// For fast reverse lookup - map from the id of a face to the id of the
// cluster to which it belongs.
let clusterIDForFaceID = new Map<string, string>();
const faceIDToClusterID = new Map<string, string>();
// For fast reverse lookup - map from the id of a cluster to its position in
// the clusters array.
let clusterIndexForFaceID = new Map<string, number>();
// For fast reverse lookup - map from the id of a face to the index of its
// corresponding cluster in the clusters array.
const faceIDToClusterIndex = new Map<string, number>();
// Fill in the maps with the existing data. Since the remote clusters come
// first, they'll be preferred over any existing local clusters for the same
// face (as is the desired behaviour).
for (const [i, cluster] of clusters.entries()) {
for (const faceID of cluster.faces) {
if (!faceIDToClusterID.has(faceID)) {
faceIDToClusterID.set(faceID, cluster.id);
faceIDToClusterIndex.set(faceID, i);
}
}
}
const state = {
faceIDToClusterID,
faceIDToClusterIndex,
clusters,
};
// Process the faces in batches, but keep an overlap between batches to
// allow "links" to form with existing clusters.
const batchSize = 10000;
const offsetIncrement = 7500;
for (let offset = 0; offset < faces.length; offset += offsetIncrement) {
const it = Date.now();
const batch = faces.slice(offset, offset + batchSize);
log.info(`[batch] processing ${offset} to ${offset + batch.length}`);
const oldState = {
clusterIDForFaceID,
clusterIndexForFaceID,
clusters,
};
const newState = await clusterBatchLinear(
await clusterBatchLinear(
batch,
oldState,
joinThreshold,
earlyExitThreshold,
badFaceHeuristics,
state,
({ completed }: ClusteringProgress) =>
onProgress({
completed: offset + completed,
@@ -186,19 +172,13 @@ export const clusterFaces = async (
}),
);
clusterIDForFaceID = newState.clusterIDForFaceID;
clusterIndexForFaceID = newState.clusterIndexForFaceID;
clusters = newState.clusters;
log.info(
`[batch] ${newState.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`,
`[batch] ${state.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`,
);
}
// Prune clusters that are smaller than the threshold.
const validClusters = clusters.filter(
(cs) => cs.faces.length > minClusterSize,
);
const validClusters = clusters.filter((cs) => cs.faces.length > 2);
const sortedClusters = validClusters.sort(
(a, b) => b.faces.length - a.faces.length,
@@ -207,9 +187,13 @@ export const clusterFaces = async (
// TODO-Cluster
// This isn't really part of the clustering, but help the main thread out by
// pre-computing temporary in-memory people, one per cluster.
// For fast reverse lookup - map from face ids to the face.
const faceForFaceID = new Map(faces.map((f) => [f.faceID, f]));
const people = toPeople(sortedClusters, localFileByID, faceForFaceID);
const clusteredFaceCount = clusterIDForFaceID.size;
const clusteredFaceCount = faceIDToClusterID.size;
const timeTakenMs = Date.now() - t;
log.info(
`Generated ${sortedClusters.length} clusters from ${faces.length} faces (${clusteredFaceCount} clustered ${faces.length - clusteredFaceCount} unclustered) (${timeTakenMs} ms)`,
@@ -230,30 +214,20 @@ export const clusterFaces = async (
* dot product calculations that will happen during clustering and attaches
* other information that the clustering algorithm needs.
*/
function* enumerateFaces(
faceIndices: FaceIndex[],
minBlur: number,
minScore: number,
) {
function* enumerateFaces(faceIndices: FaceIndex[]) {
for (const fi of faceIndices) {
for (const f of fi.faces) {
if (shouldIncludeFace(f, minBlur, minScore)) {
for (const face of fi.faces) {
if (face.blur > 10 && face.score > 0.8) {
yield {
...f,
embedding: new Float32Array(f.embedding),
isBadFace: isBadFace(f),
...face,
embedding: new Float32Array(face.embedding),
isBadFace: isBadFace(face),
};
}
}
}
}
/**
* Return true if the given face is above the minimum inclusion thresholds.
*/
const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) =>
face.blur > minBlur && face.score > minScore;
/**
* Return true if the given face is above the minimum inclusion thresholds, but
* is otherwise heuristically determined to be possibly spurious face detection.
@@ -272,29 +246,20 @@ const isSidewaysFace = (face: Face) =>
const newClusterID = () => newNonSecureID("cluster_");
interface ClusteringState {
clusterIDForFaceID: Map<string, string>;
clusterIndexForFaceID: Map<string, number>;
faceIDToClusterID: Map<string, string>;
faceIDToClusterIndex: Map<string, number>;
clusters: FaceCluster[];
}
const clusterBatchLinear = async (
faces: ClusterFace[],
oldState: ClusteringState,
joinThreshold: number,
earlyExitThreshold: number,
badFaceHeuristics: boolean,
state: ClusteringState,
onProgress: (progress: ClusteringProgress) => void,
) => {
const state: ClusteringState = {
clusterIDForFaceID: new Map(oldState.clusterIDForFaceID),
clusterIndexForFaceID: new Map(oldState.clusterIndexForFaceID),
clusters: [...oldState.clusters],
};
// Sort the faces so that the already clustered ones are at the front.
faces = faces
.filter((f) => state.clusterIDForFaceID.has(f.faceID))
.concat(faces.filter((f) => !state.clusterIDForFaceID.has(f.faceID)));
.filter((f) => state.faceIDToClusterID.has(f.faceID))
.concat(faces.filter((f) => !state.faceIDToClusterID.has(f.faceID)));
// For each face in the batch
for (const [i, fi] of faces.entries()) {
@@ -305,7 +270,7 @@ const clusterBatchLinear = async (
}
// If the face is already part of a cluster, then skip it.
if (state.clusterIDForFaceID.has(fi.faceID)) continue;
if (state.faceIDToClusterID.has(fi.faceID)) continue;
// Find the nearest neighbour among the previous faces in this batch.
let nnIndex: number | undefined;
@@ -319,15 +284,14 @@ const clusterBatchLinear = async (
// The vectors are already normalized, so we can directly use their
// dot product as their cosine similarity.
const csim = dotProduct(fi.embedding, fj.embedding);
const threshold =
badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold;
const threshold = fj.isBadFace ? 0.84 : 0.76;
if (csim > nnCosineSimilarity && csim >= threshold) {
nnIndex = j;
nnCosineSimilarity = csim;
// If we've found something "near enough", stop looking for a
// better match (A heuristic to speed up clustering).
if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break;
// If we've found something above our early exit threshold, stop
// looking for a better match (A way to speed up clustering).
if (csim >= 0.9) break;
}
}
@@ -335,12 +299,12 @@ const clusterBatchLinear = async (
// Found a neighbour close enough, add ourselves to its cluster.
const nnFace = ensure(faces[nnIndex]);
const nnClusterIndex = ensure(
state.clusterIndexForFaceID.get(nnFace.faceID),
state.faceIDToClusterIndex.get(nnFace.faceID),
);
const nnCluster = ensure(state.clusters[nnClusterIndex]);
state.clusterIDForFaceID.set(fi.faceID, nnCluster.id);
state.clusterIndexForFaceID.set(fi.faceID, nnClusterIndex);
state.faceIDToClusterID.set(fi.faceID, nnCluster.id);
state.faceIDToClusterIndex.set(fi.faceID, nnClusterIndex);
nnCluster.faces.push(fi.faceID);
} else {
// No neighbour within the threshold. Create a new cluster.
@@ -348,13 +312,11 @@ const clusterBatchLinear = async (
const clusterIndex = state.clusters.length;
const cluster = { id: clusterID, faces: [fi.faceID] };
state.clusterIDForFaceID.set(fi.faceID, cluster.id);
state.clusterIndexForFaceID.set(fi.faceID, clusterIndex);
state.faceIDToClusterID.set(fi.faceID, cluster.id);
state.faceIDToClusterIndex.set(fi.faceID, clusterIndex);
state.clusters.push(cluster);
}
}
return state;
};
/**