diff --git a/web/apps/photos/src/pages/cluster-debug.tsx b/web/apps/photos/src/pages/cluster-debug.tsx index 90c0f91c42..871b7fb3e7 100644 --- a/web/apps/photos/src/pages/cluster-debug.tsx +++ b/web/apps/photos/src/pages/cluster-debug.tsx @@ -5,7 +5,11 @@ import { wipClusterDebugPageContents, type ClusterDebugPageContents, } from "@/new/photos/services/ml"; -import { type ClusteringOpts } from "@/new/photos/services/ml/cluster"; +import { + type ClusteringOpts, + type ClusteringProgress, + type OnClusteringProgress, +} from "@/new/photos/services/ml/cluster"; import { faceDirection, type Face } from "@/new/photos/services/ml/face"; import type { EnteFile } from "@/new/photos/types/file"; import { @@ -25,10 +29,17 @@ import { TextField, Typography, } from "@mui/material"; -import { useFormik } from "formik"; +import { useFormik, type FormikProps } from "formik"; import { useRouter } from "next/router"; import { AppContext } from "pages/_app"; -import React, { useContext, useEffect, useMemo, useRef, useState } from "react"; +import React, { + memo, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; import AutoSizer from "react-virtualized-auto-sizer"; import { areEqual, @@ -44,10 +55,13 @@ export default function ClusterDebug() { ClusterDebugPageContents | undefined >(); - const cluster = async (opts: ClusteringOpts) => { + const cluster = async ( + opts: ClusteringOpts, + onProgress: OnClusteringProgress, + ) => { setClusterRes(undefined); startLoading(); - setClusterRes(await wipClusterDebugPageContents(opts)); + setClusterRes(await wipClusterDebugPageContents(opts, onProgress)); finishLoading(); }; @@ -98,28 +112,36 @@ const Container = styled("div")` `; interface OptionsFormProps { - onCluster: (opts: ClusteringOpts) => Promise; + onCluster: ( + opts: ClusteringOpts, + onProgress: OnClusteringProgress, + ) => Promise; } const OptionsForm: React.FC = ({ onCluster }) => { + const [progress, setProgress] = useState({ + completed: 0, + total: 0, + }); + // Formik converts nums to a string on edit. const toFloat = (n: number | string) => typeof n == "string" ? parseFloat(n) : n; - const { values, handleSubmit, handleChange, isSubmitting } = - useFormik({ - initialValues: { - method: "linear", - minBlur: 10, - minScore: 0.8, - minClusterSize: 2, - joinThreshold: 0.7, - earlyExitThreshold: 0.2, - batchSize: 10000, - lookbackSize: 2500, - }, - onSubmit: (values) => - onCluster({ + const formik = useFormik({ + initialValues: { + method: "linear", + minBlur: 10, + minScore: 0.8, + minClusterSize: 2, + joinThreshold: 0.7, + earlyExitThreshold: 0.2, + batchSize: 10000, + lookbackSize: 2500, + }, + onSubmit: (values) => + onCluster( + { method: values.method, minBlur: toFloat(values.minBlur), minScore: toFloat(values.minScore), @@ -128,103 +150,125 @@ const OptionsForm: React.FC = ({ onCluster }) => { earlyExitThreshold: toFloat(values.earlyExitThreshold), batchSize: toFloat(values.batchSize), lookbackSize: toFloat(values.lookbackSize), - }), - }); + }, + (progress: ClusteringProgress) => setProgress(progress), + ), + }); + + const Form = memo( + ({ + values, + handleSubmit, + handleChange, + isSubmitting, + }: FormikProps) => ( +
+ + + + {["hdbscan", "linear"].map((v) => ( + + {v} + + ))} + + + + + + + + + + + + + + + +
+ ), + ); return ( -
- - Parameters - - - {["hdbscan", "linear"].map((v) => ( - - {v} - - ))} - - - - - - - - - - - - - - - {isSubmitting && } - -
+ + Parameters +
+ {formik.isSubmitting && } + ); }; +const Loader: React.FC = ({ completed, total }) => ( + + + {`${completed} / ${total}`} + +); + type ClusterListProps = ClusterResHeaderProps & { height: number; width: number; @@ -395,12 +439,6 @@ const ClusterResHeader: React.FC = ({ clusterRes }) => { ); }; -const Loader = () => ( - - - -); - const ListItem = styled("div")` display: flex; justify-content: center; diff --git a/web/packages/new/photos/services/ml/cluster.ts b/web/packages/new/photos/services/ml/cluster.ts index 42259d6f31..a2cd6b45f1 100644 --- a/web/packages/new/photos/services/ml/cluster.ts +++ b/web/packages/new/photos/services/ml/cluster.ts @@ -125,6 +125,13 @@ export interface ClusteringOpts { lookbackSize: number; } +export interface ClusteringProgress { + completed: number; + total: number; +} + +export type OnClusteringProgress = (progress: ClusteringProgress) => void; + export interface ClusterPreview { clusterSize: number; faces: ClusterPreviewFace[]; @@ -181,6 +188,7 @@ export const clusterFaces = ( faceIndexes: FaceIndex[], localFiles: EnteFile[], opts: ClusteringOpts, + onProgress: OnClusteringProgress, ) => { const { method, diff --git a/web/packages/new/photos/services/ml/index.ts b/web/packages/new/photos/services/ml/index.ts index 37d961b165..1e4c80ca34 100644 --- a/web/packages/new/photos/services/ml/index.ts +++ b/web/packages/new/photos/services/ml/index.ts @@ -23,6 +23,7 @@ import { type ClusteringOpts, type ClusterPreviewFace, type FaceCluster, + type OnClusteringProgress, } from "./cluster"; import { regenerateFaceCrops } from "./crop"; import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db"; @@ -372,6 +373,7 @@ export interface ClusterDebugPageContents { export const wipClusterDebugPageContents = async ( opts: ClusteringOpts, + onProgress: OnClusteringProgress, ): Promise => { if (!(await wipClusterEnable())) throw new Error("Not implemented"); @@ -387,7 +389,7 @@ export const wipClusterDebugPageContents = async ( cgroups, unclusteredFaces, ...rest - } = await worker().then((w) => w.clusterFaces(opts)); + } = await worker().then((w) => w.clusterFaces(opts, onProgress)); const fileForFace = ({ faceID }: Face) => ensure(localFileByID.get(ensure(fileIDFromFaceID(faceID)))); diff --git a/web/packages/new/photos/services/ml/worker.ts b/web/packages/new/photos/services/ml/worker.ts index 391dc672f7..7b80d3282d 100644 --- a/web/packages/new/photos/services/ml/worker.ts +++ b/web/packages/new/photos/services/ml/worker.ts @@ -24,7 +24,11 @@ import { indexCLIP, type CLIPIndex, } from "./clip"; -import { clusterFaces, type ClusteringOpts } from "./cluster"; +import { + clusterFaces, + type ClusteringOpts, + type OnClusteringProgress, +} from "./cluster"; import { saveFaceCrops } from "./crop"; import { faceIndexes, @@ -276,11 +280,12 @@ export class MLWorker { } // TODO-Cluster - async clusterFaces(opts: ClusteringOpts) { + async clusterFaces(opts: ClusteringOpts, onProgress: OnClusteringProgress) { return clusterFaces( await faceIndexes(), await getAllLocalFiles(), opts, + onProgress, ); } }