propagate

This commit is contained in:
Manav Rathi
2024-08-30 16:44:58 +05:30
parent 20c742d43d
commit 598d5aab10
4 changed files with 75 additions and 158 deletions

View File

@@ -2,11 +2,12 @@ import { SelectionBar } from "@/base/components/Navbar";
import { pt } from "@/base/i18n";
import {
faceCrop,
wipClusterDebugPageContents,
type ClusterDebugPageContents,
type ClusterPreviewFaceWF,
type ClusterPreviewFaceWithFile,
} from "@/new/photos/services/ml";
import { type ClusteringOpts } from "@/new/photos/services/ml/cluster-new";
import { faceDirection } from "@/new/photos/services/ml/face";
import { wait } from "@/utils/promise";
import {
FlexWrapper,
FluidContainer,
@@ -87,12 +88,6 @@ interface ClusterListProps {
width: number;
}
interface ClusteringOpts {
method: "linear" | "hdbscan";
batchSize: number;
joinThreshold: number;
}
const ClusterList: React.FC<ClusterListProps> = ({ height, width }) => {
const { startLoading, finishLoading } = useContext(AppContext);
@@ -105,18 +100,7 @@ const ClusterList: React.FC<ClusterListProps> = ({ height, width }) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const cluster = async (opts: ClusteringOpts) => {
startLoading();
// setClusterRes(await wipClusterDebugPageContents());
console.log(opts);
await wait(5000);
setClusterRes({
clusteredCount: 1,
unclusteredCount: 2,
clusterPreviewWFs: Array(100)
.fill(0)
.map(() => ({ clusterSize: 0, faces: [] })),
clusters: [],
clusterIDForFaceID: new Map(),
});
setClusterRes(await wipClusterDebugPageContents(opts));
finishLoading();
};
@@ -136,8 +120,6 @@ const ClusterList: React.FC<ClusterListProps> = ({ height, width }) => {
listRef.current?.resetAfterIndex(0);
}, [items]);
const clusterIDForFaceID = clusterRes?.clusterIDForFaceID;
const getItemSize = (index: number) =>
index === 0
? 270
@@ -177,10 +159,10 @@ const ClusterList: React.FC<ClusterListProps> = ({ height, width }) => {
{`cluster size ${item.toFixed(2)}`}
</LabelContainer>
) : (
item.map((faceWF, i) => (
item.map((f, i) => (
<FaceItem
key={i.toString()}
{...{ faceWF, clusterIDForFaceID }}
faceWithFile={f}
/>
))
)}
@@ -192,17 +174,17 @@ const ClusterList: React.FC<ClusterListProps> = ({ height, width }) => {
);
};
type Item = number | ClusterPreviewFaceWF[];
type Item = number | ClusterPreviewFaceWithFile[];
const itemsFromClusterRes = (
clusterRes: ClusterDebugPageContents,
columns: number,
) => {
const { clusterPreviewWFs } = clusterRes;
const { clusterPreviewsWithFile } = clusterRes;
const result: Item[] = [];
for (let index = 0; index < clusterPreviewWFs.length; index++) {
const { clusterSize, faces } = clusterPreviewWFs[index];
for (let index = 0; index < clusterPreviewsWithFile.length; index++) {
const { clusterSize, faces } = clusterPreviewsWithFile[index];
result.push(clusterSize);
let lastIndex = 0;
while (lastIndex < faces.length) {
@@ -315,7 +297,7 @@ const Header: React.FC<HeaderProps> = ({ clusterRes, onCluster }) => {
const clusterInfo = clusterRes && (
<Stack m={1}>
<Typography variant="small" mb={1}>
{`${clusterRes.clusters.length} clusters from ${clusterRes.clusteredCount} faces. ${clusterRes.unclusteredCount} unclustered faces.`}
{`${clusterRes.clusters.length} clusters from ${clusterRes.clusteredFaceCount} faces. ${clusterRes.unclusteredFaceCount} unclustered faces.`}
</Typography>
<Typography variant="small" color="text.muted">
Showing only top 30 and bottom 30 clusters.
@@ -326,7 +308,10 @@ const Header: React.FC<HeaderProps> = ({ clusterRes, onCluster }) => {
</Typography>
<Typography variant="small" color="text.muted">
Below each face is its{" "}
<b>blur - score - cosineSimilarity - direction</b>
<b>blur - score - cosineSimilarity - direction</b>.
</Typography>
<Typography variant="small" color="text.muted">
Faces added to the cluster as a result of merging are outlined.
</Typography>
</Stack>
);
@@ -347,12 +332,11 @@ const Loader = () => (
);
interface FaceItemProps {
faceWF: ClusterPreviewFaceWF;
clusterIDForFaceID: Map<string, string> | undefined;
faceWithFile: ClusterPreviewFaceWithFile;
}
const FaceItem: React.FC<FaceItemProps> = ({ faceWF, clusterIDForFaceID }) => {
const { face, enteFile, cosineSimilarity } = faceWF;
const FaceItem: React.FC<FaceItemProps> = ({ faceWithFile }) => {
const { face, enteFile, cosineSimilarity, wasMerged } = faceWithFile;
const { faceID } = face;
const [objectURL, setObjectURL] = useState<string | undefined>();
@@ -377,7 +361,7 @@ const FaceItem: React.FC<FaceItemProps> = ({ faceWF, clusterIDForFaceID }) => {
return (
<FaceChip
style={{
outline: outlineForCluster(clusterIDForFaceID?.get(faceID)),
outline: wasMerged ? `1px solid gray` : undefined,
outlineOffset: "2px",
}}
>
@@ -413,9 +397,3 @@ const FaceChip = styled(Box)`
width: 120px;
height: 120px;
`;
const outlineForCluster = (clusterID: string | undefined) =>
clusterID ? `1px solid oklch(0.8 0.2 ${hForID(clusterID)})` : undefined;
const hForID = (id: string) =>
([...id].reduce((s, c) => s + c.charCodeAt(0), 0) % 10) * 36;

View File

@@ -113,15 +113,10 @@ export interface CGroup {
displayFaceID: string | undefined;
}
// TODO-Cluster
export interface FaceNeighbours {
face: Face;
neighbours: FaceNeighbour[];
}
interface FaceNeighbour {
face: Face;
cosineSimilarity: number;
export interface ClusteringOpts {
method: "linear" | "hdbscan";
batchSize: number;
joinThreshold: number;
}
export interface ClusterPreview {
@@ -129,9 +124,10 @@ export interface ClusterPreview {
faces: ClusterPreviewFace[];
}
interface ClusterPreviewFace {
export interface ClusterPreviewFace {
face: Face;
cosineSimilarity: number;
wasMerged: boolean;
}
/**
@@ -348,7 +344,11 @@ function* enumerateFaces(faceIndices: FaceIndex[]) {
}
}
export const clusterFacesHdb = (faceIndexes: FaceIndex[]) => {
export const clusterFacesHdb = (
faceIndexes: FaceIndex[],
opts: ClusteringOpts,
) => {
const { batch } = opts;
const t = Date.now();
// A flattened array of faces.

View File

@@ -20,7 +20,11 @@ import { getAllLocalFiles } from "../files";
import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
import type { SearchPerson } from "../search/types";
import type { UploadItem } from "../upload/types";
import { type CGroup, type FaceCluster } from "./cluster-new";
import {
type ClusteringOpts,
type ClusterPreviewFace,
type FaceCluster,
} from "./cluster-new";
import { regenerateFaceCrops } from "./crop";
import { clearMLDB, faceIndex, indexableAndIndexedCounts } from "./db";
import type { Face } from "./face";
@@ -344,42 +348,30 @@ export const wipSearchPersons = async () => {
return _wip_searchPersons ?? [];
};
export interface FaceFileNeighbours {
face: Face;
neighbours: FaceFileNeighbour[];
}
export interface FaceFileNeighbour {
face: Face;
enteFile: EnteFile;
cosineSimilarity: number;
}
// "with file"
export interface ClusterPreviewWF {
export interface ClusterPreviewWithFile {
clusterSize: number;
faces: ClusterPreviewFaceWF[];
faces: ClusterPreviewFaceWithFile[];
}
export interface ClusterPreviewFaceWF {
face: Face;
export type ClusterPreviewFaceWithFile = ClusterPreviewFace & {
enteFile: EnteFile;
cosineSimilarity: number;
}
};
export interface ClusterDebugPageContents {
clusteredCount: number;
unclusteredCount: number;
// faceFNs: FaceFileNeighbours[];
clusterPreviewWFs: ClusterPreviewWF[];
clusteredFaceCount: number;
unclusteredFaceCount: number;
clusters: FaceCluster[];
clusterIDForFaceID: Map<string, string>;
clusterPreviewsWithFile: ClusterPreviewWithFile[];
unclusteredFacesWithFile: {
face: Face;
enteFile: EnteFile;
};
}
export const wipClusterDebugPageContents = async (): Promise<
ClusterDebugPageContents | undefined
> => {
if (!(await wipClusterEnable())) return undefined;
export const wipClusterDebugPageContents = async (
opts: ClusteringOpts,
): Promise<ClusterDebugPageContents> => {
if (!(await wipClusterEnable())) throw new Error("Not implemented");
log.info("clustering");
_wip_isClustering = true;
@@ -388,38 +380,33 @@ export const wipClusterDebugPageContents = async (): Promise<
// const { faceAndNeigbours, clusters, cgroups } = await clusterFaces(
const {
clusteredCount,
unclusteredCount,
clusteredFaceCount,
unclusteredFaceCount,
clusterPreviews,
clusters,
cgroups,
clusterIDForFaceID,
} = await worker().then((w) => w.clusterFacesHdb());
// const searchPersons = await convertToSearchPersons(clusters, cgroups);
unclusteredFaces,
} = await worker().then((w) => w.clusterFacesHdb(opts));
const localFiles = await getAllLocalFiles();
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
const fileForFace = ({ faceID }: Face) =>
ensure(localFileByID.get(ensure(fileIDFromFaceID(faceID))));
// const faceFNs = faceAndNeigbours.map(
// ({ topFace: face, faces: neighbours }) => ({
// face,
// neighbours: neighbours.map(({ face, cosineSimilarity }) => ({
// face,
// enteFile: fileForFace(face),
// cosineSimilarity,
// })),
// }),
// );
const clusterPreviewWFs = clusterPreviews.map(({ clusterSize, faces }) => ({
clusterSize,
faces: faces.map(({ face, cosineSimilarity }) => ({
face,
enteFile: fileForFace(face),
cosineSimilarity,
})),
const clusterPreviewsWithFile = clusterPreviews.map(
({ clusterSize, faces }) => ({
clusterSize,
faces: faces.map(({ face, cosineSimilarity }) => ({
face,
enteFile: fileForFace(face),
cosineSimilarity,
})),
}),
);
const unclusteredFacesWithFile = unclusteredFaces.map((face) => ({
face,
enteFile: fileForFace(face),
}));
const clusterByID = new Map(clusters.map((c) => [c.id, c]));
@@ -453,62 +440,14 @@ export const wipClusterDebugPageContents = async (): Promise<
triggerStatusUpdate();
return {
clusteredCount,
unclusteredCount,
clusterPreviewWFs,
clusteredFaceCount,
unclusteredFaceCount,
clusters,
clusterIDForFaceID,
clusterPreviewsWithFile,
unclusteredFacesWithFile,
};
};
export const wipCluster = () => void wipClusterDebugPageContents();
// TODO-Cluster remove me
export const convertToSearchPersons = async (
clusters: FaceCluster[],
cgroups: CGroup[],
) => {
const clusterByID = new Map(clusters.map((c) => [c.id, c]));
const localFiles = await getAllLocalFiles();
const localFileByID = new Map(localFiles.map((f) => [f.id, f]));
const result: SearchPerson[] = [];
for (const cgroup of cgroups) {
const displayFaceID = cgroup.displayFaceID;
if (!displayFaceID) {
// TODO-Cluster
assertionFailed(`cgroup ${cgroup.id} without displayFaceID`);
continue;
}
const displayFaceFileID = fileIDFromFaceID(displayFaceID);
if (!displayFaceFileID) continue;
const displayFaceFile = localFileByID.get(displayFaceFileID);
if (!displayFaceFile) {
assertionFailed(`Face ID ${displayFaceFileID} without local file`);
continue;
}
const fileIDs = cgroup.clusterIDs
.map((id) => clusterByID.get(id))
.flatMap((cluster) => cluster?.faceIDs ?? [])
.map((faceID) => fileIDFromFaceID(faceID))
.filter((fileID) => fileID !== undefined);
result.push({
id: cgroup.id,
name: cgroup.name,
files: [...new Set(fileIDs)],
displayFaceID,
displayFaceFile,
});
}
return result.sort((a, b) => b.files.length - a.files.length);
};
export type MLStatus =
| { phase: "disabled" /* The ML remote flag is off */ }
| {

View File

@@ -24,7 +24,7 @@ import {
indexCLIP,
type CLIPIndex,
} from "./clip";
import { clusterFacesHdb } from "./cluster-new";
import { clusterFacesHdb, type ClusteringOpts } from "./cluster-new";
import { saveFaceCrops } from "./crop";
import {
faceIndexes,
@@ -276,8 +276,8 @@ export class MLWorker {
}
// TODO-Cluster
async clusterFacesHdb() {
return clusterFacesHdb(await faceIndexes());
async clusterFacesHdb(opts: ClusteringOpts) {
return clusterFacesHdb(await faceIndexes(), opts);
}
}