diff --git a/web/apps/photos/src/pages/cluster-debug.tsx b/web/apps/photos/src/pages/cluster-debug.tsx index 3b9a7ad230..1f6a1036f6 100644 --- a/web/apps/photos/src/pages/cluster-debug.tsx +++ b/web/apps/photos/src/pages/cluster-debug.tsx @@ -6,9 +6,9 @@ import { type ClusterDebugPageContents, } from "@/new/photos/services/ml"; import { + type ClusterFace, type ClusteringOpts, type ClusteringProgress, - type FaceF32, type OnClusteringProgress, } from "@/new/photos/services/ml/cluster"; import { faceDirection } from "@/new/photos/services/ml/face"; @@ -22,6 +22,8 @@ import BackButton from "@mui/icons-material/ArrowBackOutlined"; import { Box, Button, + Checkbox, + FormControlLabel, IconButton, LinearProgress, Stack, @@ -68,10 +70,11 @@ export default function ClusterDebug() { minBlur: 10, minScore: 0.8, minClusterSize: 2, - joinThreshold: 0.6, + joinThreshold: 0.76, earlyExitThreshold: 0.9, batchSize: 10000, offsetIncrement: 7500, + badFaceHeuristics: true, }, onSubmit: (values) => cluster( @@ -83,6 +86,7 @@ export default function ClusterDebug() { earlyExitThreshold: toFloat(values.earlyExitThreshold), batchSize: toFloat(values.batchSize), offsetIncrement: toFloat(values.offsetIncrement), + badFaceHeuristics: values.badFaceHeuristics, }, (progress: ClusteringProgress) => onProgressRef.current?.(progress), @@ -227,7 +231,22 @@ const MemoizedForm = memo( onChange={handleChange} /> - + + + } + label={ + + Bad face heuristics + + } + /> - + ), @@ -325,7 +344,7 @@ const ClusterList: React.FC> = ({ index === 0 ? 140 : index === 1 - ? 130 + ? 110 : Array.isArray(items[index - 2]) ? listItemHeight : 36; @@ -447,15 +466,11 @@ const ClusterResHeader: React.FC = ({ clusterRes }) => { For each cluster showing only up to 50 faces, sorted by cosine - similarity to highest scoring face in the cluster. + similarity to its highest scoring face. - Below each face is its{" "} - blur - score - cosineSimilarity - direction. - - - Faces added to the cluster as a result of next batch merging are - outlined. + Below each face is its blur, score, cosineSimilarity, direction. + Bad faces are outlined. ); @@ -494,15 +509,15 @@ interface FaceItemProps { } interface FaceWithFile { - face: FaceF32; + face: ClusterFace; enteFile: EnteFile; cosineSimilarity?: number; wasMerged?: boolean; } const FaceItem: React.FC = ({ faceWithFile }) => { - const { face, enteFile, cosineSimilarity, wasMerged } = faceWithFile; - const { faceID } = face; + const { face, enteFile, cosineSimilarity } = faceWithFile; + const { faceID, isBadFace } = face; const [objectURL, setObjectURL] = useState(); @@ -526,7 +541,7 @@ const FaceItem: React.FC = ({ faceWithFile }) => { return ( diff --git a/web/packages/new/photos/services/ml/cluster.ts b/web/packages/new/photos/services/ml/cluster.ts index 1aa665666a..e4928aa991 100644 --- a/web/packages/new/photos/services/ml/cluster.ts +++ b/web/packages/new/photos/services/ml/cluster.ts @@ -3,7 +3,7 @@ import { newNonSecureID } from "@/base/id-worker"; import log from "@/base/log"; import { ensure } from "@/utils/ensure"; import type { EnteFile } from "../../types/file"; -import type { Face, FaceIndex } from "./face"; +import { faceDirection, type Face, type FaceIndex } from "./face"; import { dotProduct } from "./math"; /** @@ -121,6 +121,7 @@ export interface ClusteringOpts { earlyExitThreshold: number; batchSize: number; offsetIncrement: number; + badFaceHeuristics: boolean; } export interface ClusteringProgress { @@ -130,8 +131,10 @@ export interface ClusteringProgress { export type OnClusteringProgress = (progress: ClusteringProgress) => void; -export type FaceF32 = Omit & { +/** A {@link Face} annotated with data needed during clustering. */ +export type ClusterFace = Omit & { embedding: Float32Array; + isBadFace: boolean; }; export interface ClusterPreview { @@ -140,7 +143,7 @@ export interface ClusterPreview { } export interface ClusterPreviewFace { - face: FaceF32; + face: ClusterFace; cosineSimilarity: number; wasMerged: boolean; } @@ -148,25 +151,17 @@ export interface ClusterPreviewFace { /** * Cluster faces into groups. * - * [Note: Face clustering algorithm] - * * A cgroup (cluster group) consists of clusters, each of which itself is a set * of faces. * * cgroup << cluster << face * - * The clusters are generated locally by clients using the following algorithm: + * This function generates clusters locally using a batched form of linear + * clustering, with a bit of lookback (and a dollop of heuristics) to get the + * clusters to merge across batches. * - * 1. clusters = [] initially, or fetched from remote. - * - * 2. For each face, find its nearest neighbour in the embedding space. - * - * 3. If no such neighbour is found within our threshold, create a new cluster. - * - * 4. Otherwise assign this face to the same cluster as its nearest neighbour. - * - * This user can then tweak the output of the algorithm by performing the - * following actions to the list of clusters that they can see: + * This user can later tweak these clusters by performing the following actions + * to the list of clusters that they can see: * * - They can provide a name for a cluster ("name a person"). This upgrades a * cluster into a "cgroup", which is an entity that gets synced via remote @@ -200,16 +195,14 @@ export const clusterFaces = ( earlyExitThreshold, batchSize, offsetIncrement, + badFaceHeuristics, } = opts; const t = Date.now(); const localFileByID = new Map(localFiles.map((f) => [f.id, f])); - // A flattened array of faces. - const allFaces = [...enumerateFaces(faceIndexes)]; - const filteredFaces = allFaces - .filter((f) => f.blur > minBlur) - .filter((f) => f.score > minScore); + // A flattened array of filtered and annotated faces. + const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)]; const fileForFaceID = new Map( filteredFaces.map(({ faceID }) => [ @@ -264,6 +257,7 @@ export const clusterFaces = ( oldState, joinThreshold, earlyExitThreshold, + badFaceHeuristics, ({ completed }: ClusteringProgress) => onProgress({ completed: offset + completed, @@ -335,7 +329,9 @@ export const clusterFaces = ( }); } - const totalFaceCount = allFaces.length; + // TODO-Cluster the total face count is only needed during debugging + let totalFaceCount = 0; + for (const fi of faceIndexes) totalFaceCount += fi.faces.length; const filteredFaceCount = faces.length; const clusteredFaceCount = clusterIDForFaceID.size; const unclusteredFaceCount = filteredFaceCount - clusteredFaceCount; @@ -364,20 +360,51 @@ export const clusterFaces = ( }; /** - * A generator function that returns a stream of {faceID, embedding} values, - * flattening all the the faces present in the given {@link faceIndices}. + * A generator function that returns a stream of eligible {@link ClusterFace}s + * by flattening all the the faces present in the given {@link faceIndices}. * - * It also converts the embeddings to Float32Arrays to speed up the dot product - * calculations that will happen during clustering. + * During this, it also converts the embeddings to Float32Arrays to speed up the + * dot product calculations that will happen during clustering and attaches + * other information that the clustering algorithm needs. */ -function* enumerateFaces(faceIndices: FaceIndex[]) { +function* enumerateFaces( + faceIndices: FaceIndex[], + minBlur: number, + minScore: number, +) { for (const fi of faceIndices) { for (const f of fi.faces) { - yield { ...f, embedding: new Float32Array(f.embedding) }; + if (shouldIncludeFace(f, minBlur, minScore)) { + yield { + ...f, + embedding: new Float32Array(f.embedding), + isBadFace: isBadFace(f), + }; + } } } } +/** + * Return true if the given face is above the minimum inclusion thresholds. + */ +const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) => + face.blur > minBlur && face.score > minScore; + +/** + * Return true if the given face is above the minimum inclusion thresholds, but + * is otherwise heuristically determined to be possibly spurious face detection. + * + * We apply a higher threshold when clustering such faces. + */ +const isBadFace = (face: Face) => + face.blur < 50 || + (face.blur < 200 && face.blur < 0.85) || + isSidewaysFace(face); + +const isSidewaysFace = (face: Face) => + faceDirection(face.detection) != "straight"; + /** Generate a new cluster ID. */ const newClusterID = () => newNonSecureID("cluster_"); @@ -403,10 +430,11 @@ interface ClusteringState { } const clusterBatchLinear = ( - faces: FaceF32[], + faces: ClusterFace[], oldState: ClusteringState, joinThreshold: number, earlyExitThreshold: number, + badFaceHeuristics: boolean, onProgress: (progress: ClusteringProgress) => void, ) => { const state: ClusteringState = { @@ -429,7 +457,7 @@ const clusterBatchLinear = ( // Find the nearest neighbour among the previous faces in this batch. let nnIndex: number | undefined; - let nnCosineSimilarity = joinThreshold; + let nnCosineSimilarity = 0; for (let j = i - 1; j >= 0; j--) { // ! This is an O(n^2) loop, be careful when adding more code here. @@ -439,13 +467,15 @@ const clusterBatchLinear = ( // The vectors are already normalized, so we can directly use their // dot product as their cosine similarity. const csim = dotProduct(fi.embedding, fj.embedding); - if (csim > nnCosineSimilarity) { + const threshold = + badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold; + if (csim > nnCosineSimilarity && csim >= threshold) { nnIndex = j; nnCosineSimilarity = csim; // If we've found something "near enough", stop looking for a // better match (A heuristic to speed up clustering). - if (earlyExitThreshold > 0 && csim > earlyExitThreshold) break; + if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break; } } diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts index 276db1083c..b3ce72ae0d 100644 --- a/web/packages/new/photos/services/ml/index.ts +++ b/web/packages/new/photos/services/ml/index.ts @@ -20,10 +20,10 @@ import { getRemoteFlag, updateRemoteFlag } from "../remote-store"; import type { SearchPerson } from "../search/types"; import type { UploadItem } from "../upload/types"; import { + type ClusterFace, type ClusteringOpts, type ClusterPreviewFace, type FaceCluster, - type FaceF32, type OnClusteringProgress, } from "./cluster"; import { regenerateFaceCrops } from "./crop"; @@ -366,7 +366,7 @@ export interface ClusterDebugPageContents { clusters: FaceCluster[]; clusterPreviewsWithFile: ClusterPreviewWithFile[]; unclusteredFacesWithFile: { - face: FaceF32; + face: ClusterFace; enteFile: EnteFile; }[]; }