This commit is contained in:
Manav Rathi
2024-09-02 17:08:58 +05:30
parent 371dcf8ab9
commit 3cef7ff545
4 changed files with 174 additions and 121 deletions

View File

@@ -5,7 +5,11 @@ import {
wipClusterDebugPageContents,
type ClusterDebugPageContents,
} from "@/new/photos/services/ml";
import { type ClusteringOpts } from "@/new/photos/services/ml/cluster";
import {
type ClusteringOpts,
type ClusteringProgress,
type OnClusteringProgress,
} from "@/new/photos/services/ml/cluster";
import { faceDirection, type Face } from "@/new/photos/services/ml/face";
import type { EnteFile } from "@/new/photos/types/file";
import {
@@ -25,10 +29,17 @@ import {
TextField,
Typography,
} from "@mui/material";
import { useFormik } from "formik";
import { useFormik, type FormikProps } from "formik";
import { useRouter } from "next/router";
import { AppContext } from "pages/_app";
import React, { useContext, useEffect, useMemo, useRef, useState } from "react";
import React, {
memo,
useContext,
useEffect,
useMemo,
useRef,
useState,
} from "react";
import AutoSizer from "react-virtualized-auto-sizer";
import {
areEqual,
@@ -44,10 +55,13 @@ export default function ClusterDebug() {
ClusterDebugPageContents | undefined
>();
const cluster = async (opts: ClusteringOpts) => {
const cluster = async (
opts: ClusteringOpts,
onProgress: OnClusteringProgress,
) => {
setClusterRes(undefined);
startLoading();
setClusterRes(await wipClusterDebugPageContents(opts));
setClusterRes(await wipClusterDebugPageContents(opts, onProgress));
finishLoading();
};
@@ -98,28 +112,36 @@ const Container = styled("div")`
`;
interface OptionsFormProps {
onCluster: (opts: ClusteringOpts) => Promise<void>;
onCluster: (
opts: ClusteringOpts,
onProgress: OnClusteringProgress,
) => Promise<void>;
}
const OptionsForm: React.FC<OptionsFormProps> = ({ onCluster }) => {
const [progress, setProgress] = useState<ClusteringProgress>({
completed: 0,
total: 0,
});
// Formik converts nums to a string on edit.
const toFloat = (n: number | string) =>
typeof n == "string" ? parseFloat(n) : n;
const { values, handleSubmit, handleChange, isSubmitting } =
useFormik<ClusteringOpts>({
initialValues: {
method: "linear",
minBlur: 10,
minScore: 0.8,
minClusterSize: 2,
joinThreshold: 0.7,
earlyExitThreshold: 0.2,
batchSize: 10000,
lookbackSize: 2500,
},
onSubmit: (values) =>
onCluster({
const formik = useFormik<ClusteringOpts>({
initialValues: {
method: "linear",
minBlur: 10,
minScore: 0.8,
minClusterSize: 2,
joinThreshold: 0.7,
earlyExitThreshold: 0.2,
batchSize: 10000,
lookbackSize: 2500,
},
onSubmit: (values) =>
onCluster(
{
method: values.method,
minBlur: toFloat(values.minBlur),
minScore: toFloat(values.minScore),
@@ -128,103 +150,125 @@ const OptionsForm: React.FC<OptionsFormProps> = ({ onCluster }) => {
earlyExitThreshold: toFloat(values.earlyExitThreshold),
batchSize: toFloat(values.batchSize),
lookbackSize: toFloat(values.lookbackSize),
}),
});
},
(progress: ClusteringProgress) => setProgress(progress),
),
});
const Form = memo(
({
values,
handleSubmit,
handleChange,
isSubmitting,
}: FormikProps<ClusteringOpts>) => (
<form onSubmit={handleSubmit}>
<Stack>
<Stack
direction="row"
gap={1}
sx={{ ".MuiFormControl-root": { flex: "1" } }}
>
<TextField
name="method"
label="method"
value={values.method}
select
size="small"
onChange={handleChange}
>
{["hdbscan", "linear"].map((v) => (
<MenuItem key={v} value={v}>
{v}
</MenuItem>
))}
</TextField>
<TextField
name="minBlur"
label="minBlur"
value={values.minBlur}
size="small"
onChange={handleChange}
/>
<TextField
name="minScore"
label="minScore"
value={values.minScore}
size="small"
onChange={handleChange}
/>
<TextField
name="minClusterSize"
label="minClusterSize"
value={values.minClusterSize}
size="small"
onChange={handleChange}
/>
</Stack>
<Stack
direction="row"
gap={1}
sx={{ ".MuiFormControl-root": { flex: "1" } }}
>
<TextField
name="joinThreshold"
label="joinThreshold"
value={values.joinThreshold}
size="small"
onChange={handleChange}
/>
<TextField
name="earlyExitThreshold"
label="earlyExitThreshold"
value={values.earlyExitThreshold}
size="small"
onChange={handleChange}
/>
<TextField
name="batchSize"
label="batchSize"
value={values.batchSize}
size="small"
onChange={handleChange}
/>
<TextField
name="lookbackSize"
label="lookbackSize"
value={values.lookbackSize}
size="small"
onChange={handleChange}
/>
</Stack>
<Box marginInlineStart={"auto"} p={1}>
<Button
color="secondary"
type="submit"
disabled={isSubmitting}
>
Cluster
</Button>
</Box>
</Stack>
</form>
),
);
return (
<form onSubmit={handleSubmit}>
<Stack>
<Typography paddingInline={1}>Parameters</Typography>
<Stack
direction="row"
gap={1}
sx={{ ".MuiFormControl-root": { flex: "1" } }}
>
<TextField
name="method"
label="method"
value={values.method}
select
size="small"
onChange={handleChange}
>
{["hdbscan", "linear"].map((v) => (
<MenuItem key={v} value={v}>
{v}
</MenuItem>
))}
</TextField>
<TextField
name="minBlur"
label="minBlur"
value={values.minBlur}
size="small"
onChange={handleChange}
/>
<TextField
name="minScore"
label="minScore"
value={values.minScore}
size="small"
onChange={handleChange}
/>
<TextField
name="minClusterSize"
label="minClusterSize"
value={values.minClusterSize}
size="small"
onChange={handleChange}
/>
</Stack>
<Stack
direction="row"
gap={1}
sx={{ ".MuiFormControl-root": { flex: "1" } }}
>
<TextField
name="joinThreshold"
label="joinThreshold"
value={values.joinThreshold}
size="small"
onChange={handleChange}
/>
<TextField
name="earlyExitThreshold"
label="earlyExitThreshold"
value={values.earlyExitThreshold}
size="small"
onChange={handleChange}
/>
<TextField
name="batchSize"
label="batchSize"
value={values.batchSize}
size="small"
onChange={handleChange}
/>
<TextField
name="lookbackSize"
label="lookbackSize"
value={values.lookbackSize}
size="small"
onChange={handleChange}
/>
</Stack>
<Box marginInlineStart={"auto"} p={1}>
<Button
color="secondary"
type="submit"
disabled={isSubmitting}
>
Cluster
</Button>
</Box>
{isSubmitting && <Loader />}
</Stack>
</form>
<Stack>
<Typography paddingInline={1}>Parameters</Typography>
<Form {...formik} />
{formik.isSubmitting && <Loader {...progress} />}
</Stack>
);
};
const Loader: React.FC<ClusteringProgress> = ({ completed, total }) => (
<VerticallyCentered mt={4}>
<EnteSpinner />
<Typography>{`${completed} / ${total}`}</Typography>
</VerticallyCentered>
);
type ClusterListProps = ClusterResHeaderProps & {
height: number;
width: number;
@@ -395,12 +439,6 @@ const ClusterResHeader: React.FC<ClusterResHeaderProps> = ({ clusterRes }) => {
);
};
const Loader = () => (
<VerticallyCentered mt={4}>
<EnteSpinner />
</VerticallyCentered>
);
const ListItem = styled("div")`
display: flex;
justify-content: center;

View File

@@ -125,6 +125,13 @@ export interface ClusteringOpts {
lookbackSize: number;
}
export interface ClusteringProgress {
completed: number;
total: number;
}
export type OnClusteringProgress = (progress: ClusteringProgress) => void;
export interface ClusterPreview {
clusterSize: number;
faces: ClusterPreviewFace[];
@@ -181,6 +188,7 @@ export const clusterFaces = (
faceIndexes: FaceIndex[],
localFiles: EnteFile[],
opts: ClusteringOpts,
onProgress: OnClusteringProgress,
) => {
const {
method,

View File

@@ -23,6 +23,7 @@ import {
type ClusteringOpts,
type ClusterPreviewFace,
type FaceCluster,
type OnClusteringProgress,
} from "./cluster";
import { regenerateFaceCrops } from "./crop";
import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
@@ -372,6 +373,7 @@ export interface ClusterDebugPageContents {
export const wipClusterDebugPageContents = async (
opts: ClusteringOpts,
onProgress: OnClusteringProgress,
): Promise<ClusterDebugPageContents> => {
if (!(await wipClusterEnable())) throw new Error("Not implemented");
@@ -387,7 +389,7 @@ export const wipClusterDebugPageContents = async (
cgroups,
unclusteredFaces,
...rest
} = await worker().then((w) => w.clusterFaces(opts));
} = await worker().then((w) => w.clusterFaces(opts, onProgress));
const fileForFace = ({ faceID }: Face) =>
ensure(localFileByID.get(ensure(fileIDFromFaceID(faceID))));

View File

@@ -24,7 +24,11 @@ import {
indexCLIP,
type CLIPIndex,
} from "./clip";
import { clusterFaces, type ClusteringOpts } from "./cluster";
import {
clusterFaces,
type ClusteringOpts,
type OnClusteringProgress,
} from "./cluster";
import { saveFaceCrops } from "./crop";
import {
faceIndexes,
@@ -276,11 +280,12 @@ export class MLWorker {
}
// TODO-Cluster
async clusterFaces(opts: ClusteringOpts) {
async clusterFaces(opts: ClusteringOpts, onProgress: OnClusteringProgress) {
return clusterFaces(
await faceIndexes(),
await getAllLocalFiles(),
opts,
onProgress,
);
}
}