Start with existing clusters
This commit is contained in:
@@ -3,6 +3,8 @@ import log from "@/base/log";
|
||||
import { ensure } from "@/utils/ensure";
|
||||
import { wait } from "@/utils/promise";
|
||||
import type { EnteFile } from "../../types/file";
|
||||
import { savedCGroups } from "../user-entity";
|
||||
import { savedFaceClusters } from "./db";
|
||||
import {
|
||||
faceDirection,
|
||||
fileIDFromFaceID,
|
||||
@@ -31,17 +33,6 @@ export interface FaceCluster {
|
||||
faces: string[];
|
||||
}
|
||||
|
||||
const clusteringOptions = {
|
||||
minBlur: 10,
|
||||
minScore: 0.8,
|
||||
minClusterSize: 2,
|
||||
joinThreshold: 0.76,
|
||||
earlyExitThreshold: 0.9,
|
||||
batchSize: 10000,
|
||||
offsetIncrement: 7500,
|
||||
badFaceHeuristics: true,
|
||||
};
|
||||
|
||||
export interface ClusteringProgress {
|
||||
completed: number;
|
||||
total: number;
|
||||
@@ -86,22 +77,12 @@ export const clusterFaces = async (
|
||||
localFiles: EnteFile[],
|
||||
onProgress: (progress: ClusteringProgress) => void,
|
||||
) => {
|
||||
const {
|
||||
minBlur,
|
||||
minScore,
|
||||
minClusterSize,
|
||||
joinThreshold,
|
||||
earlyExitThreshold,
|
||||
batchSize,
|
||||
offsetIncrement,
|
||||
badFaceHeuristics,
|
||||
} = clusteringOptions;
|
||||
const t = Date.now();
|
||||
|
||||
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
|
||||
|
||||
// A flattened array of filtered and annotated faces.
|
||||
const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)];
|
||||
const filteredFaces = [...enumerateFaces(faceIndexes)];
|
||||
|
||||
const fileForFaceID = new Map(
|
||||
filteredFaces.map(({ faceID }) => [
|
||||
@@ -121,64 +102,69 @@ export const clusterFaces = async (
|
||||
fileForFace(a).metadata.creationTime,
|
||||
);
|
||||
|
||||
// For fast reverse lookup - map from face ids to the face.
|
||||
const faceForFaceID = new Map(faces.map((f) => [f.faceID, f]));
|
||||
|
||||
// The existing, and resultant, clusters.
|
||||
//
|
||||
// We start with clusters we have currently (whether from remote or from a
|
||||
// previous local run). This function is called after we have already synced
|
||||
// cgroups with remote, so we know that it is up to date.
|
||||
//
|
||||
// We generate two sets of maps.
|
||||
//
|
||||
// Remote:
|
||||
// - faceID -> clusterID
|
||||
// - clusterID -> cgroupID
|
||||
//
|
||||
// Local:
|
||||
// - faceID -> clusterID
|
||||
//
|
||||
// Then merge them into a single set of maps, giving preference to the
|
||||
// information that we get from remote.
|
||||
//
|
||||
// We then run clustering using these maps (skipping over faces that have
|
||||
// already been clustered).
|
||||
// previous local run, but preferring the remote ones). This function is
|
||||
// called after we have already synced cgroups with remote, so we know that
|
||||
// it is up to date.
|
||||
|
||||
// TODO-Cluster Later on, instead of starting from a blank slate, this will
|
||||
// be the list of existing clusters we fetch from remote.
|
||||
// - fetchRemoteClusterFeedback
|
||||
// [..local, ..remote]
|
||||
let clusters: FaceCluster[] = [];
|
||||
|
||||
// Get the locally available remote cluster groups.
|
||||
const cgroups = await savedCGroups();
|
||||
// Sort them so that the latest ones are first.
|
||||
const sortedCGroups = cgroups.sort((a, b) => b.updatedAt - a.updatedAt);
|
||||
// Extract the remote clusters.
|
||||
clusters = clusters.concat(
|
||||
sortedCGroups.map((cg) => cg.data.assigned).flat(),
|
||||
);
|
||||
|
||||
// Add on the clusters we have available locally.
|
||||
clusters = clusters.concat(await savedFaceClusters());
|
||||
|
||||
// For fast reverse lookup - map from the id of a face to the id of the
|
||||
// cluster to which it belongs.
|
||||
let clusterIDForFaceID = new Map<string, string>();
|
||||
const faceIDToClusterID = new Map<string, string>();
|
||||
|
||||
// For fast reverse lookup - map from the id of a cluster to its position in
|
||||
// the clusters array.
|
||||
let clusterIndexForFaceID = new Map<string, number>();
|
||||
// For fast reverse lookup - map from the id of a face to the index of its
|
||||
// corresponding cluster in the clusters array.
|
||||
const faceIDToClusterIndex = new Map<string, number>();
|
||||
|
||||
// Fill in the maps with the existing data. Since the remote clusters come
|
||||
// first, they'll be preferred over any existing local clusters for the same
|
||||
// face (as is the desired behaviour).
|
||||
|
||||
for (const [i, cluster] of clusters.entries()) {
|
||||
for (const faceID of cluster.faces) {
|
||||
if (!faceIDToClusterID.has(faceID)) {
|
||||
faceIDToClusterID.set(faceID, cluster.id);
|
||||
faceIDToClusterIndex.set(faceID, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const state = {
|
||||
faceIDToClusterID,
|
||||
faceIDToClusterIndex,
|
||||
clusters,
|
||||
};
|
||||
|
||||
// Process the faces in batches, but keep an overlap between batches to
|
||||
// allow "links" to form with existing clusters.
|
||||
|
||||
const batchSize = 10000;
|
||||
const offsetIncrement = 7500;
|
||||
|
||||
for (let offset = 0; offset < faces.length; offset += offsetIncrement) {
|
||||
const it = Date.now();
|
||||
|
||||
const batch = faces.slice(offset, offset + batchSize);
|
||||
log.info(`[batch] processing ${offset} to ${offset + batch.length}`);
|
||||
|
||||
const oldState = {
|
||||
clusterIDForFaceID,
|
||||
clusterIndexForFaceID,
|
||||
clusters,
|
||||
};
|
||||
|
||||
const newState = await clusterBatchLinear(
|
||||
await clusterBatchLinear(
|
||||
batch,
|
||||
oldState,
|
||||
joinThreshold,
|
||||
earlyExitThreshold,
|
||||
badFaceHeuristics,
|
||||
state,
|
||||
({ completed }: ClusteringProgress) =>
|
||||
onProgress({
|
||||
completed: offset + completed,
|
||||
@@ -186,19 +172,13 @@ export const clusterFaces = async (
|
||||
}),
|
||||
);
|
||||
|
||||
clusterIDForFaceID = newState.clusterIDForFaceID;
|
||||
clusterIndexForFaceID = newState.clusterIndexForFaceID;
|
||||
clusters = newState.clusters;
|
||||
|
||||
log.info(
|
||||
`[batch] ${newState.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`,
|
||||
`[batch] ${state.clusters.length} clusters from ${batch.length} faces (${Date.now() - it} ms)`,
|
||||
);
|
||||
}
|
||||
|
||||
// Prune clusters that are smaller than the threshold.
|
||||
const validClusters = clusters.filter(
|
||||
(cs) => cs.faces.length > minClusterSize,
|
||||
);
|
||||
const validClusters = clusters.filter((cs) => cs.faces.length > 2);
|
||||
|
||||
const sortedClusters = validClusters.sort(
|
||||
(a, b) => b.faces.length - a.faces.length,
|
||||
@@ -207,9 +187,13 @@ export const clusterFaces = async (
|
||||
// TODO-Cluster
|
||||
// This isn't really part of the clustering, but helps the main thread out by
|
||||
// pre-computing temporary in-memory people, one per cluster.
|
||||
|
||||
// For fast reverse lookup - map from face ids to the face.
|
||||
const faceForFaceID = new Map(faces.map((f) => [f.faceID, f]));
|
||||
|
||||
const people = toPeople(sortedClusters, localFileByID, faceForFaceID);
|
||||
|
||||
const clusteredFaceCount = clusterIDForFaceID.size;
|
||||
const clusteredFaceCount = faceIDToClusterID.size;
|
||||
const timeTakenMs = Date.now() - t;
|
||||
log.info(
|
||||
`Generated ${sortedClusters.length} clusters from ${faces.length} faces (${clusteredFaceCount} clustered ${faces.length - clusteredFaceCount} unclustered) (${timeTakenMs} ms)`,
|
||||
@@ -230,30 +214,20 @@ export const clusterFaces = async (
|
||||
* dot product calculations that will happen during clustering and attaches
|
||||
* other information that the clustering algorithm needs.
|
||||
*/
|
||||
function* enumerateFaces(
|
||||
faceIndices: FaceIndex[],
|
||||
minBlur: number,
|
||||
minScore: number,
|
||||
) {
|
||||
function* enumerateFaces(faceIndices: FaceIndex[]) {
|
||||
for (const fi of faceIndices) {
|
||||
for (const f of fi.faces) {
|
||||
if (shouldIncludeFace(f, minBlur, minScore)) {
|
||||
for (const face of fi.faces) {
|
||||
if (face.blur > 10 && face.score > 0.8) {
|
||||
yield {
|
||||
...f,
|
||||
embedding: new Float32Array(f.embedding),
|
||||
isBadFace: isBadFace(f),
|
||||
...face,
|
||||
embedding: new Float32Array(face.embedding),
|
||||
isBadFace: isBadFace(face),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return true if the given face is above the minimum inclusion thresholds.
|
||||
*/
|
||||
const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) =>
|
||||
face.blur > minBlur && face.score > minScore;
|
||||
|
||||
/**
|
||||
* Return true if the given face is above the minimum inclusion thresholds, but
|
||||
* is otherwise heuristically determined to be a possibly spurious face detection.
|
||||
@@ -272,29 +246,20 @@ const isSidewaysFace = (face: Face) =>
|
||||
const newClusterID = () => newNonSecureID("cluster_");
|
||||
|
||||
interface ClusteringState {
|
||||
clusterIDForFaceID: Map<string, string>;
|
||||
clusterIndexForFaceID: Map<string, number>;
|
||||
faceIDToClusterID: Map<string, string>;
|
||||
faceIDToClusterIndex: Map<string, number>;
|
||||
clusters: FaceCluster[];
|
||||
}
|
||||
|
||||
const clusterBatchLinear = async (
|
||||
faces: ClusterFace[],
|
||||
oldState: ClusteringState,
|
||||
joinThreshold: number,
|
||||
earlyExitThreshold: number,
|
||||
badFaceHeuristics: boolean,
|
||||
state: ClusteringState,
|
||||
onProgress: (progress: ClusteringProgress) => void,
|
||||
) => {
|
||||
const state: ClusteringState = {
|
||||
clusterIDForFaceID: new Map(oldState.clusterIDForFaceID),
|
||||
clusterIndexForFaceID: new Map(oldState.clusterIndexForFaceID),
|
||||
clusters: [...oldState.clusters],
|
||||
};
|
||||
|
||||
// Sort the faces so that the already clustered ones are at the front.
|
||||
faces = faces
|
||||
.filter((f) => state.clusterIDForFaceID.has(f.faceID))
|
||||
.concat(faces.filter((f) => !state.clusterIDForFaceID.has(f.faceID)));
|
||||
.filter((f) => state.faceIDToClusterID.has(f.faceID))
|
||||
.concat(faces.filter((f) => !state.faceIDToClusterID.has(f.faceID)));
|
||||
|
||||
// For each face in the batch
|
||||
for (const [i, fi] of faces.entries()) {
|
||||
@@ -305,7 +270,7 @@ const clusterBatchLinear = async (
|
||||
}
|
||||
|
||||
// If the face is already part of a cluster, then skip it.
|
||||
if (state.clusterIDForFaceID.has(fi.faceID)) continue;
|
||||
if (state.faceIDToClusterID.has(fi.faceID)) continue;
|
||||
|
||||
// Find the nearest neighbour among the previous faces in this batch.
|
||||
let nnIndex: number | undefined;
|
||||
@@ -319,15 +284,14 @@ const clusterBatchLinear = async (
|
||||
// The vectors are already normalized, so we can directly use their
|
||||
// dot product as their cosine similarity.
|
||||
const csim = dotProduct(fi.embedding, fj.embedding);
|
||||
const threshold =
|
||||
badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold;
|
||||
const threshold = fj.isBadFace ? 0.84 : 0.76;
|
||||
if (csim > nnCosineSimilarity && csim >= threshold) {
|
||||
nnIndex = j;
|
||||
nnCosineSimilarity = csim;
|
||||
|
||||
// If we've found something "near enough", stop looking for a
|
||||
// better match (A heuristic to speed up clustering).
|
||||
if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break;
|
||||
// If we've found something above our early exit threshold, stop
|
||||
// looking for a better match (A way to speed up clustering).
|
||||
if (csim >= 0.9) break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,12 +299,12 @@ const clusterBatchLinear = async (
|
||||
// Found a neighbour close enough, add ourselves to its cluster.
|
||||
const nnFace = ensure(faces[nnIndex]);
|
||||
const nnClusterIndex = ensure(
|
||||
state.clusterIndexForFaceID.get(nnFace.faceID),
|
||||
state.faceIDToClusterIndex.get(nnFace.faceID),
|
||||
);
|
||||
const nnCluster = ensure(state.clusters[nnClusterIndex]);
|
||||
|
||||
state.clusterIDForFaceID.set(fi.faceID, nnCluster.id);
|
||||
state.clusterIndexForFaceID.set(fi.faceID, nnClusterIndex);
|
||||
state.faceIDToClusterID.set(fi.faceID, nnCluster.id);
|
||||
state.faceIDToClusterIndex.set(fi.faceID, nnClusterIndex);
|
||||
nnCluster.faces.push(fi.faceID);
|
||||
} else {
|
||||
// No neighbour within the threshold. Create a new cluster.
|
||||
@@ -348,13 +312,11 @@ const clusterBatchLinear = async (
|
||||
const clusterIndex = state.clusters.length;
|
||||
const cluster = { id: clusterID, faces: [fi.faceID] };
|
||||
|
||||
state.clusterIDForFaceID.set(fi.faceID, cluster.id);
|
||||
state.clusterIndexForFaceID.set(fi.faceID, clusterIndex);
|
||||
state.faceIDToClusterID.set(fi.faceID, cluster.id);
|
||||
state.faceIDToClusterIndex.set(fi.faceID, clusterIndex);
|
||||
state.clusters.push(cluster);
|
||||
}
|
||||
}
|
||||
|
||||
return state;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user