This commit is contained in:
Manav Rathi
2024-08-13 19:14:25 +05:30
parent 7f9391f89f
commit 5e4f0d4caf

View File

@@ -1,6 +1,7 @@
import { newNonSecureID } from "@/base/id-worker";
import log from "@/base/log";
import { ensure } from "@/utils/ensure";
import { faceClusters } from "./db";
import type { FaceIndex } from "./face";
import { dotProduct } from "./math";
@@ -63,7 +64,7 @@ export interface Person {
*
* The clusters are generated using locally by clients using this algorithm:
*
* 1. clusters = []
* 1. clusters = [] initially, or fetched from remote.
*
* 2. For each face, find its nearest neighbour in the embedding space from
* amongst the faces that have already been clustered.
@@ -82,18 +83,45 @@ export interface Person {
*
* - They can remove a cluster from a person.
*/
export const clusterFaces = (faceIndexes: FaceIndex[]) => {
export const clusterFaces = async (faceIndexes: FaceIndex[]) => {
const t = Date.now();
// The face data that we need (face ID and its embedding).
const faces = [...faceIDAndEmbeddings(faceIndexes)];
let clusters: FaceCluster[] = [];
const clusterIndexByFaceID = new Map<string, number>();
// Start with the clusters we already have (either from a previous indexing,
// or fetched from remote).
const clusters = await faceClusters();
// For fast reverse lookup - map from cluster ids to the index in the
// clusters array.
const clusterIndexForClusterID = new Map(clusters.map((c, i) => [c.id, i]));
// For fast reverse lookup - map from face ids to the id of the cluster to
// which they belong.
const clusterIDForFaceID = new Map(
clusters.flatMap((c) =>
c.faceIDs.map((faceID) => [faceID, c.id] as const),
),
);
// Generate a new cluster ID
const newClusterID = () => newNonSecureID("cluster_");
// For each face
for (const [i, { faceID, embedding }] of faces.entries()) {
// Find the nearest neighbour from among the faces we have already seen.
// If the face is already part of a cluster, then skip it.
if (clusterIDForFaceID.get(faceID)) continue;
// Find the nearest neighbour from among all the other faces.
let nnIndex: number | undefined;
let nnCosineSimilarity = 0;
for (let j = 0; j < i; j++) {
for (let j = 0; j < faces.length; j++) {
// ! This is an O(n^2) loop, be careful when adding more code here.
// Skip itself
if (i == j) continue;
// Can't find a way of avoiding the null assertion.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const n = faces[j]!;
@@ -106,36 +134,60 @@ export const clusterFaces = (faceIndexes: FaceIndex[]) => {
nnCosineSimilarity = csim;
}
}
if (nnIndex === undefined) {
// We didn't find a neighbour. Create a new cluster with this face.
const cluster = {
id: newNonSecureID("cluster_"),
faceIDs: [faceID],
};
if (nnIndex === undefined) {
// We didn't find a neighbour within the threshold. Create a new
// cluster with this face.
const cluster = { id: newClusterID(), faceIDs: [faceID] };
clusters.push(cluster);
clusterIndexByFaceID.set(faceID, clusters.length);
clusterIndexForClusterID.set(cluster.id, clusters.length);
clusterIDForFaceID.set(faceID, cluster.id);
} else {
// Found a neighbour near enough. Add this face to the neighbour's
// cluster.
// Found a neighbour near enough.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const nn = faces[nnIndex]!;
const nnClusterIndex = ensure(clusterIndexByFaceID.get(nn.faceID));
clusters[nnClusterIndex]?.faceIDs.push(faceID);
clusterIndexByFaceID.set(faceID, nnClusterIndex);
// Find the cluster the nearest neighbour belongs to, if any.
const nnClusterID = clusterIDForFaceID.get(nn.faceID);
if (nnClusterID) {
// If the neighbour is already part of a cluster, also add
// ourselves to that cluster.
const nnClusterIndex = ensure(
clusterIndexForClusterID.get(nnClusterID),
);
clusters[nnClusterIndex]?.faceIDs.push(faceID);
clusterIDForFaceID.set(faceID, nnClusterID);
} else {
// Create a new cluster with us and our nearest neighbour.
const cluster = {
id: newClusterID(),
faceIDs: [faceID, nn.faceID],
};
clusters.push(cluster);
clusterIndexForClusterID.set(cluster.id, clusters.length);
clusterIDForFaceID.set(faceID, cluster.id);
clusterIDForFaceID.set(nn.faceID, cluster.id);
}
}
}
clusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1);
const validClusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1);
log.debug(() => ["ml/cluster", { faces, clusters, clusterIndexByFaceID }]);
log.debug(() => [
"ml/cluster",
{ faces, validClusters, clusterIndexForClusterID, clusterIDForFaceID },
]);
log.debug(
() =>
`Clustered ${faces.length} faces into ${clusters.length} clusters (${Date.now() - t} ms)`,
`Clustered ${faces.length} faces into ${validClusters.length} clusters (${Date.now() - t} ms)`,
);
return clusters;
return validClusters;
};
/**