[desktop] Clustering - Incorporate low quality face heuristics (#3123)
This commit is contained in:
@@ -6,9 +6,9 @@ import {
|
||||
type ClusterDebugPageContents,
|
||||
} from "@/new/photos/services/ml";
|
||||
import {
|
||||
type ClusterFace,
|
||||
type ClusteringOpts,
|
||||
type ClusteringProgress,
|
||||
type FaceF32,
|
||||
type OnClusteringProgress,
|
||||
} from "@/new/photos/services/ml/cluster";
|
||||
import { faceDirection } from "@/new/photos/services/ml/face";
|
||||
@@ -22,6 +22,8 @@ import BackButton from "@mui/icons-material/ArrowBackOutlined";
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Checkbox,
|
||||
FormControlLabel,
|
||||
IconButton,
|
||||
LinearProgress,
|
||||
Stack,
|
||||
@@ -68,10 +70,11 @@ export default function ClusterDebug() {
|
||||
minBlur: 10,
|
||||
minScore: 0.8,
|
||||
minClusterSize: 2,
|
||||
joinThreshold: 0.6,
|
||||
joinThreshold: 0.76,
|
||||
earlyExitThreshold: 0.9,
|
||||
batchSize: 10000,
|
||||
offsetIncrement: 7500,
|
||||
badFaceHeuristics: true,
|
||||
},
|
||||
onSubmit: (values) =>
|
||||
cluster(
|
||||
@@ -83,6 +86,7 @@ export default function ClusterDebug() {
|
||||
earlyExitThreshold: toFloat(values.earlyExitThreshold),
|
||||
batchSize: toFloat(values.batchSize),
|
||||
offsetIncrement: toFloat(values.offsetIncrement),
|
||||
badFaceHeuristics: values.badFaceHeuristics,
|
||||
},
|
||||
(progress: ClusteringProgress) =>
|
||||
onProgressRef.current?.(progress),
|
||||
@@ -227,7 +231,22 @@ const MemoizedForm = memo(
|
||||
onChange={handleChange}
|
||||
/>
|
||||
</Stack>
|
||||
<Box marginInlineStart={"auto"} p={1}>
|
||||
<Stack direction="row" justifyContent={"space-between"} p={1}>
|
||||
<FormControlLabel
|
||||
control={
|
||||
<Checkbox
|
||||
name={"badFaceHeuristics"}
|
||||
checked={values.badFaceHeuristics}
|
||||
size="small"
|
||||
onChange={handleChange}
|
||||
/>
|
||||
}
|
||||
label={
|
||||
<Typography color="text.secondary">
|
||||
Bad face heuristics
|
||||
</Typography>
|
||||
}
|
||||
/>
|
||||
<Button
|
||||
color="secondary"
|
||||
type="submit"
|
||||
@@ -235,7 +254,7 @@ const MemoizedForm = memo(
|
||||
>
|
||||
Cluster
|
||||
</Button>
|
||||
</Box>
|
||||
</Stack>
|
||||
</Stack>
|
||||
</form>
|
||||
),
|
||||
@@ -325,7 +344,7 @@ const ClusterList: React.FC<React.PropsWithChildren<ClusterListProps>> = ({
|
||||
index === 0
|
||||
? 140
|
||||
: index === 1
|
||||
? 130
|
||||
? 110
|
||||
: Array.isArray(items[index - 2])
|
||||
? listItemHeight
|
||||
: 36;
|
||||
@@ -447,15 +466,11 @@ const ClusterResHeader: React.FC<ClusterResHeaderProps> = ({ clusterRes }) => {
|
||||
</Typography>
|
||||
<Typography variant="small" color="text.muted">
|
||||
For each cluster showing only up to 50 faces, sorted by cosine
|
||||
similarity to highest scoring face in the cluster.
|
||||
similarity to its highest scoring face.
|
||||
</Typography>
|
||||
<Typography variant="small" color="text.muted">
|
||||
Below each face is its{" "}
|
||||
<b>blur - score - cosineSimilarity - direction</b>.
|
||||
</Typography>
|
||||
<Typography variant="small" color="text.muted">
|
||||
Faces added to the cluster as a result of next batch merging are
|
||||
outlined.
|
||||
Below each face is its blur, score, cosineSimilarity, direction.
|
||||
Bad faces are outlined.
|
||||
</Typography>
|
||||
</Stack>
|
||||
);
|
||||
@@ -494,15 +509,15 @@ interface FaceItemProps {
|
||||
}
|
||||
|
||||
interface FaceWithFile {
|
||||
face: FaceF32;
|
||||
face: ClusterFace;
|
||||
enteFile: EnteFile;
|
||||
cosineSimilarity?: number;
|
||||
wasMerged?: boolean;
|
||||
}
|
||||
|
||||
const FaceItem: React.FC<FaceItemProps> = ({ faceWithFile }) => {
|
||||
const { face, enteFile, cosineSimilarity, wasMerged } = faceWithFile;
|
||||
const { faceID } = face;
|
||||
const { face, enteFile, cosineSimilarity } = faceWithFile;
|
||||
const { faceID, isBadFace } = face;
|
||||
|
||||
const [objectURL, setObjectURL] = useState<string | undefined>();
|
||||
|
||||
@@ -526,7 +541,7 @@ const FaceItem: React.FC<FaceItemProps> = ({ faceWithFile }) => {
|
||||
return (
|
||||
<FaceChip
|
||||
style={{
|
||||
outline: wasMerged ? `1px solid gray` : undefined,
|
||||
outline: isBadFace ? `1px solid rosybrown` : undefined,
|
||||
outlineOffset: "2px",
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { newNonSecureID } from "@/base/id-worker";
|
||||
import log from "@/base/log";
|
||||
import { ensure } from "@/utils/ensure";
|
||||
import type { EnteFile } from "../../types/file";
|
||||
import type { Face, FaceIndex } from "./face";
|
||||
import { faceDirection, type Face, type FaceIndex } from "./face";
|
||||
import { dotProduct } from "./math";
|
||||
|
||||
/**
|
||||
@@ -121,6 +121,7 @@ export interface ClusteringOpts {
|
||||
earlyExitThreshold: number;
|
||||
batchSize: number;
|
||||
offsetIncrement: number;
|
||||
badFaceHeuristics: boolean;
|
||||
}
|
||||
|
||||
export interface ClusteringProgress {
|
||||
@@ -130,8 +131,10 @@ export interface ClusteringProgress {
|
||||
|
||||
export type OnClusteringProgress = (progress: ClusteringProgress) => void;
|
||||
|
||||
export type FaceF32 = Omit<Face, "embedding"> & {
|
||||
/** A {@link Face} annotated with data needed during clustering. */
|
||||
export type ClusterFace = Omit<Face, "embedding"> & {
|
||||
embedding: Float32Array;
|
||||
isBadFace: boolean;
|
||||
};
|
||||
|
||||
export interface ClusterPreview {
|
||||
@@ -140,7 +143,7 @@ export interface ClusterPreview {
|
||||
}
|
||||
|
||||
export interface ClusterPreviewFace {
|
||||
face: FaceF32;
|
||||
face: ClusterFace;
|
||||
cosineSimilarity: number;
|
||||
wasMerged: boolean;
|
||||
}
|
||||
@@ -148,25 +151,17 @@ export interface ClusterPreviewFace {
|
||||
/**
|
||||
* Cluster faces into groups.
|
||||
*
|
||||
* [Note: Face clustering algorithm]
|
||||
*
|
||||
* A cgroup (cluster group) consists of clusters, each of which itself is a set
|
||||
* of faces.
|
||||
*
|
||||
* cgroup << cluster << face
|
||||
*
|
||||
* The clusters are generated locally by clients using the following algorithm:
|
||||
* This function generates clusters locally using a batched form of linear
|
||||
* clustering, with a bit of lookback (and a dollop of heuristics) to get the
|
||||
* clusters to merge across batches.
|
||||
*
|
||||
* 1. clusters = [] initially, or fetched from remote.
|
||||
*
|
||||
* 2. For each face, find its nearest neighbour in the embedding space.
|
||||
*
|
||||
* 3. If no such neighbour is found within our threshold, create a new cluster.
|
||||
*
|
||||
* 4. Otherwise assign this face to the same cluster as its nearest neighbour.
|
||||
*
|
||||
* This user can then tweak the output of the algorithm by performing the
|
||||
* following actions to the list of clusters that they can see:
|
||||
* This user can later tweak these clusters by performing the following actions
|
||||
* to the list of clusters that they can see:
|
||||
*
|
||||
* - They can provide a name for a cluster ("name a person"). This upgrades a
|
||||
* cluster into a "cgroup", which is an entity that gets synced via remote
|
||||
@@ -200,16 +195,14 @@ export const clusterFaces = (
|
||||
earlyExitThreshold,
|
||||
batchSize,
|
||||
offsetIncrement,
|
||||
badFaceHeuristics,
|
||||
} = opts;
|
||||
const t = Date.now();
|
||||
|
||||
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
|
||||
|
||||
// A flattened array of faces.
|
||||
const allFaces = [...enumerateFaces(faceIndexes)];
|
||||
const filteredFaces = allFaces
|
||||
.filter((f) => f.blur > minBlur)
|
||||
.filter((f) => f.score > minScore);
|
||||
// A flattened array of filtered and annotated faces.
|
||||
const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)];
|
||||
|
||||
const fileForFaceID = new Map(
|
||||
filteredFaces.map(({ faceID }) => [
|
||||
@@ -264,6 +257,7 @@ export const clusterFaces = (
|
||||
oldState,
|
||||
joinThreshold,
|
||||
earlyExitThreshold,
|
||||
badFaceHeuristics,
|
||||
({ completed }: ClusteringProgress) =>
|
||||
onProgress({
|
||||
completed: offset + completed,
|
||||
@@ -335,7 +329,9 @@ export const clusterFaces = (
|
||||
});
|
||||
}
|
||||
|
||||
const totalFaceCount = allFaces.length;
|
||||
// TODO-Cluster the total face count is only needed during debugging
|
||||
let totalFaceCount = 0;
|
||||
for (const fi of faceIndexes) totalFaceCount += fi.faces.length;
|
||||
const filteredFaceCount = faces.length;
|
||||
const clusteredFaceCount = clusterIDForFaceID.size;
|
||||
const unclusteredFaceCount = filteredFaceCount - clusteredFaceCount;
|
||||
@@ -364,20 +360,51 @@ export const clusterFaces = (
|
||||
};
|
||||
|
||||
/**
|
||||
* A generator function that returns a stream of {faceID, embedding} values,
|
||||
* flattening all the the faces present in the given {@link faceIndices}.
|
||||
* A generator function that returns a stream of eligible {@link ClusterFace}s
|
||||
* by flattening all the the faces present in the given {@link faceIndices}.
|
||||
*
|
||||
* It also converts the embeddings to Float32Arrays to speed up the dot product
|
||||
* calculations that will happen during clustering.
|
||||
* During this, it also converts the embeddings to Float32Arrays to speed up the
|
||||
* dot product calculations that will happen during clustering and attaches
|
||||
* other information that the clustering algorithm needs.
|
||||
*/
|
||||
function* enumerateFaces(faceIndices: FaceIndex[]) {
|
||||
function* enumerateFaces(
|
||||
faceIndices: FaceIndex[],
|
||||
minBlur: number,
|
||||
minScore: number,
|
||||
) {
|
||||
for (const fi of faceIndices) {
|
||||
for (const f of fi.faces) {
|
||||
yield { ...f, embedding: new Float32Array(f.embedding) };
|
||||
if (shouldIncludeFace(f, minBlur, minScore)) {
|
||||
yield {
|
||||
...f,
|
||||
embedding: new Float32Array(f.embedding),
|
||||
isBadFace: isBadFace(f),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return true if the given face is above the minimum inclusion thresholds.
|
||||
*/
|
||||
const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) =>
|
||||
face.blur > minBlur && face.score > minScore;
|
||||
|
||||
/**
|
||||
* Return true if the given face is above the minimum inclusion thresholds, but
|
||||
* is otherwise heuristically determined to be possibly spurious face detection.
|
||||
*
|
||||
* We apply a higher threshold when clustering such faces.
|
||||
*/
|
||||
const isBadFace = (face: Face) =>
|
||||
face.blur < 50 ||
|
||||
(face.blur < 200 && face.blur < 0.85) ||
|
||||
isSidewaysFace(face);
|
||||
|
||||
const isSidewaysFace = (face: Face) =>
|
||||
faceDirection(face.detection) != "straight";
|
||||
|
||||
/** Generate a new cluster ID. */
|
||||
const newClusterID = () => newNonSecureID("cluster_");
|
||||
|
||||
@@ -403,10 +430,11 @@ interface ClusteringState {
|
||||
}
|
||||
|
||||
const clusterBatchLinear = (
|
||||
faces: FaceF32[],
|
||||
faces: ClusterFace[],
|
||||
oldState: ClusteringState,
|
||||
joinThreshold: number,
|
||||
earlyExitThreshold: number,
|
||||
badFaceHeuristics: boolean,
|
||||
onProgress: (progress: ClusteringProgress) => void,
|
||||
) => {
|
||||
const state: ClusteringState = {
|
||||
@@ -429,7 +457,7 @@ const clusterBatchLinear = (
|
||||
|
||||
// Find the nearest neighbour among the previous faces in this batch.
|
||||
let nnIndex: number | undefined;
|
||||
let nnCosineSimilarity = joinThreshold;
|
||||
let nnCosineSimilarity = 0;
|
||||
for (let j = i - 1; j >= 0; j--) {
|
||||
// ! This is an O(n^2) loop, be careful when adding more code here.
|
||||
|
||||
@@ -439,13 +467,15 @@ const clusterBatchLinear = (
|
||||
// The vectors are already normalized, so we can directly use their
|
||||
// dot product as their cosine similarity.
|
||||
const csim = dotProduct(fi.embedding, fj.embedding);
|
||||
if (csim > nnCosineSimilarity) {
|
||||
const threshold =
|
||||
badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold;
|
||||
if (csim > nnCosineSimilarity && csim >= threshold) {
|
||||
nnIndex = j;
|
||||
nnCosineSimilarity = csim;
|
||||
|
||||
// If we've found something "near enough", stop looking for a
|
||||
// better match (A heuristic to speed up clustering).
|
||||
if (earlyExitThreshold > 0 && csim > earlyExitThreshold) break;
|
||||
if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,10 +20,10 @@ import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
|
||||
import type { SearchPerson } from "../search/types";
|
||||
import type { UploadItem } from "../upload/types";
|
||||
import {
|
||||
type ClusterFace,
|
||||
type ClusteringOpts,
|
||||
type ClusterPreviewFace,
|
||||
type FaceCluster,
|
||||
type FaceF32,
|
||||
type OnClusteringProgress,
|
||||
} from "./cluster";
|
||||
import { regenerateFaceCrops } from "./crop";
|
||||
@@ -366,7 +366,7 @@ export interface ClusterDebugPageContents {
|
||||
clusters: FaceCluster[];
|
||||
clusterPreviewsWithFile: ClusterPreviewWithFile[];
|
||||
unclusteredFacesWithFile: {
|
||||
face: FaceF32;
|
||||
face: ClusterFace;
|
||||
enteFile: EnteFile;
|
||||
}[];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user