diff --git a/web/apps/photos/src/pages/cluster-debug.tsx b/web/apps/photos/src/pages/cluster-debug.tsx
index 3b9a7ad230..1f6a1036f6 100644
--- a/web/apps/photos/src/pages/cluster-debug.tsx
+++ b/web/apps/photos/src/pages/cluster-debug.tsx
@@ -6,9 +6,9 @@ import {
type ClusterDebugPageContents,
} from "@/new/photos/services/ml";
import {
+ type ClusterFace,
type ClusteringOpts,
type ClusteringProgress,
- type FaceF32,
type OnClusteringProgress,
} from "@/new/photos/services/ml/cluster";
import { faceDirection } from "@/new/photos/services/ml/face";
@@ -22,6 +22,8 @@ import BackButton from "@mui/icons-material/ArrowBackOutlined";
import {
Box,
Button,
+ Checkbox,
+ FormControlLabel,
IconButton,
LinearProgress,
Stack,
@@ -68,10 +70,11 @@ export default function ClusterDebug() {
minBlur: 10,
minScore: 0.8,
minClusterSize: 2,
- joinThreshold: 0.6,
+ joinThreshold: 0.76,
earlyExitThreshold: 0.9,
batchSize: 10000,
offsetIncrement: 7500,
+ badFaceHeuristics: true,
},
onSubmit: (values) =>
cluster(
@@ -83,6 +86,7 @@ export default function ClusterDebug() {
earlyExitThreshold: toFloat(values.earlyExitThreshold),
batchSize: toFloat(values.batchSize),
offsetIncrement: toFloat(values.offsetIncrement),
+ badFaceHeuristics: values.badFaceHeuristics,
},
(progress: ClusteringProgress) =>
onProgressRef.current?.(progress),
@@ -227,7 +231,22 @@ const MemoizedForm = memo(
onChange={handleChange}
/>
-
+
+
+ }
+ label={
+
+ Bad face heuristics
+
+ }
+ />
-
+
),
@@ -325,7 +344,7 @@ const ClusterList: React.FC> = ({
index === 0
? 140
: index === 1
- ? 130
+ ? 110
: Array.isArray(items[index - 2])
? listItemHeight
: 36;
@@ -447,15 +466,11 @@ const ClusterResHeader: React.FC = ({ clusterRes }) => {
For each cluster showing only up to 50 faces, sorted by cosine
- similarity to highest scoring face in the cluster.
+ similarity to its highest scoring face.
- Below each face is its{" "}
- blur - score - cosineSimilarity - direction.
-
-
- Faces added to the cluster as a result of next batch merging are
- outlined.
+ Below each face is its blur, score, cosineSimilarity, direction.
+ Bad faces are outlined.
);
@@ -494,15 +509,15 @@ interface FaceItemProps {
}
interface FaceWithFile {
- face: FaceF32;
+ face: ClusterFace;
enteFile: EnteFile;
cosineSimilarity?: number;
wasMerged?: boolean;
}
const FaceItem: React.FC = ({ faceWithFile }) => {
- const { face, enteFile, cosineSimilarity, wasMerged } = faceWithFile;
- const { faceID } = face;
+ const { face, enteFile, cosineSimilarity } = faceWithFile;
+ const { faceID, isBadFace } = face;
const [objectURL, setObjectURL] = useState();
@@ -526,7 +541,7 @@ const FaceItem: React.FC = ({ faceWithFile }) => {
return (
diff --git a/web/packages/new/photos/services/ml/cluster.ts b/web/packages/new/photos/services/ml/cluster.ts
index 1aa665666a..e4928aa991 100644
--- a/web/packages/new/photos/services/ml/cluster.ts
+++ b/web/packages/new/photos/services/ml/cluster.ts
@@ -3,7 +3,7 @@ import { newNonSecureID } from "@/base/id-worker";
import log from "@/base/log";
import { ensure } from "@/utils/ensure";
import type { EnteFile } from "../../types/file";
-import type { Face, FaceIndex } from "./face";
+import { faceDirection, type Face, type FaceIndex } from "./face";
import { dotProduct } from "./math";
/**
@@ -121,6 +121,7 @@ export interface ClusteringOpts {
earlyExitThreshold: number;
batchSize: number;
offsetIncrement: number;
+ badFaceHeuristics: boolean;
}
export interface ClusteringProgress {
@@ -130,8 +131,10 @@ export interface ClusteringProgress {
export type OnClusteringProgress = (progress: ClusteringProgress) => void;
-export type FaceF32 = Omit & {
+/** A {@link Face} annotated with data needed during clustering. */
+export type ClusterFace = Omit & {
embedding: Float32Array;
+ isBadFace: boolean;
};
export interface ClusterPreview {
@@ -140,7 +143,7 @@ export interface ClusterPreview {
}
export interface ClusterPreviewFace {
- face: FaceF32;
+ face: ClusterFace;
cosineSimilarity: number;
wasMerged: boolean;
}
@@ -148,25 +151,17 @@ export interface ClusterPreviewFace {
/**
* Cluster faces into groups.
*
- * [Note: Face clustering algorithm]
- *
* A cgroup (cluster group) consists of clusters, each of which itself is a set
* of faces.
*
* cgroup << cluster << face
*
- * The clusters are generated locally by clients using the following algorithm:
+ * This function generates clusters locally using a batched form of linear
+ * clustering, with a bit of lookback (and a dollop of heuristics) to get the
+ * clusters to merge across batches.
*
- * 1. clusters = [] initially, or fetched from remote.
- *
- * 2. For each face, find its nearest neighbour in the embedding space.
- *
- * 3. If no such neighbour is found within our threshold, create a new cluster.
- *
- * 4. Otherwise assign this face to the same cluster as its nearest neighbour.
- *
- * This user can then tweak the output of the algorithm by performing the
- * following actions to the list of clusters that they can see:
+ * This user can later tweak these clusters by performing the following actions
+ * to the list of clusters that they can see:
*
* - They can provide a name for a cluster ("name a person"). This upgrades a
* cluster into a "cgroup", which is an entity that gets synced via remote
@@ -200,16 +195,14 @@ export const clusterFaces = (
earlyExitThreshold,
batchSize,
offsetIncrement,
+ badFaceHeuristics,
} = opts;
const t = Date.now();
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
- // A flattened array of faces.
- const allFaces = [...enumerateFaces(faceIndexes)];
- const filteredFaces = allFaces
- .filter((f) => f.blur > minBlur)
- .filter((f) => f.score > minScore);
+ // A flattened array of filtered and annotated faces.
+ const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)];
const fileForFaceID = new Map(
filteredFaces.map(({ faceID }) => [
@@ -264,6 +257,7 @@ export const clusterFaces = (
oldState,
joinThreshold,
earlyExitThreshold,
+ badFaceHeuristics,
({ completed }: ClusteringProgress) =>
onProgress({
completed: offset + completed,
@@ -335,7 +329,9 @@ export const clusterFaces = (
});
}
- const totalFaceCount = allFaces.length;
+ // TODO-Cluster the total face count is only needed during debugging
+ let totalFaceCount = 0;
+ for (const fi of faceIndexes) totalFaceCount += fi.faces.length;
const filteredFaceCount = faces.length;
const clusteredFaceCount = clusterIDForFaceID.size;
const unclusteredFaceCount = filteredFaceCount - clusteredFaceCount;
@@ -364,20 +360,51 @@ export const clusterFaces = (
};
/**
- * A generator function that returns a stream of {faceID, embedding} values,
- * flattening all the the faces present in the given {@link faceIndices}.
+ * A generator function that returns a stream of eligible {@link ClusterFace}s
+ * by flattening all the the faces present in the given {@link faceIndices}.
*
- * It also converts the embeddings to Float32Arrays to speed up the dot product
- * calculations that will happen during clustering.
+ * During this, it also converts the embeddings to Float32Arrays to speed up the
+ * dot product calculations that will happen during clustering and attaches
+ * other information that the clustering algorithm needs.
*/
-function* enumerateFaces(faceIndices: FaceIndex[]) {
+function* enumerateFaces(
+ faceIndices: FaceIndex[],
+ minBlur: number,
+ minScore: number,
+) {
for (const fi of faceIndices) {
for (const f of fi.faces) {
- yield { ...f, embedding: new Float32Array(f.embedding) };
+ if (shouldIncludeFace(f, minBlur, minScore)) {
+ yield {
+ ...f,
+ embedding: new Float32Array(f.embedding),
+ isBadFace: isBadFace(f),
+ };
+ }
}
}
}
+/**
+ * Return true if the given face is above the minimum inclusion thresholds.
+ */
+const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) =>
+ face.blur > minBlur && face.score > minScore;
+
+/**
+ * Return true if the given face is above the minimum inclusion thresholds, but
+ * is otherwise heuristically determined to be possibly spurious face detection.
+ *
+ * We apply a higher threshold when clustering such faces.
+ */
+const isBadFace = (face: Face) =>
+ face.blur < 50 ||
+ (face.blur < 200 && face.blur < 0.85) ||
+ isSidewaysFace(face);
+
+const isSidewaysFace = (face: Face) =>
+ faceDirection(face.detection) != "straight";
+
/** Generate a new cluster ID. */
const newClusterID = () => newNonSecureID("cluster_");
@@ -403,10 +430,11 @@ interface ClusteringState {
}
const clusterBatchLinear = (
- faces: FaceF32[],
+ faces: ClusterFace[],
oldState: ClusteringState,
joinThreshold: number,
earlyExitThreshold: number,
+ badFaceHeuristics: boolean,
onProgress: (progress: ClusteringProgress) => void,
) => {
const state: ClusteringState = {
@@ -429,7 +457,7 @@ const clusterBatchLinear = (
// Find the nearest neighbour among the previous faces in this batch.
let nnIndex: number | undefined;
- let nnCosineSimilarity = joinThreshold;
+ let nnCosineSimilarity = 0;
for (let j = i - 1; j >= 0; j--) {
// ! This is an O(n^2) loop, be careful when adding more code here.
@@ -439,13 +467,15 @@ const clusterBatchLinear = (
// The vectors are already normalized, so we can directly use their
// dot product as their cosine similarity.
const csim = dotProduct(fi.embedding, fj.embedding);
- if (csim > nnCosineSimilarity) {
+ const threshold =
+ badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold;
+ if (csim > nnCosineSimilarity && csim >= threshold) {
nnIndex = j;
nnCosineSimilarity = csim;
// If we've found something "near enough", stop looking for a
// better match (A heuristic to speed up clustering).
- if (earlyExitThreshold > 0 && csim > earlyExitThreshold) break;
+ if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break;
}
}
diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts
index 276db1083c..b3ce72ae0d 100644
--- a/web/packages/new/photos/services/ml/index.ts
+++ b/web/packages/new/photos/services/ml/index.ts
@@ -20,10 +20,10 @@ import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
import type { SearchPerson } from "../search/types";
import type { UploadItem } from "../upload/types";
import {
+ type ClusterFace,
type ClusteringOpts,
type ClusterPreviewFace,
type FaceCluster,
- type FaceF32,
type OnClusteringProgress,
} from "./cluster";
import { regenerateFaceCrops } from "./crop";
@@ -366,7 +366,7 @@ export interface ClusterDebugPageContents {
clusters: FaceCluster[];
clusterPreviewsWithFile: ClusterPreviewWithFile[];
unclusteredFacesWithFile: {
- face: FaceF32;
+ face: ClusterFace;
enteFile: EnteFile;
}[];
}