This commit is contained in:
Manav Rathi
2024-10-12 11:20:15 +05:30
parent 7506a35f1f
commit 1cb0799bd0
5 changed files with 69 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
import type { ElectronMLWorker } from "@/base/types/ipc";
import type { ImageBitmapAndData } from "./blob";
import { getCLIPIndexes } from "./db";
import { savedCLIPIndexes } from "./db";
import { dotProduct, norm } from "./math";
import type { CLIPMatches } from "./worker-types";
@@ -169,7 +169,7 @@ let _cachedCLIPIndexes:
* produces potentially new CLIP indexes).
*/
const cachedOrReadCLIPIndexes = async () =>
(_cachedCLIPIndexes ??= (await getCLIPIndexes()).map(
(_cachedCLIPIndexes ??= (await savedCLIPIndexes()).map(
({ fileID, embedding }) => ({
fileID,
embedding: new Float32Array(embedding),

View File

@@ -206,7 +206,7 @@ const newFileStatus = (fileID: number): FileStatus => ({
/**
* Return the {@link FaceIndex}, if any, for {@link fileID}.
*/
export const getFaceIndex = async (fileID: number) => {
export const savedFaceIndex = async (fileID: number) => {
const db = await mlDB();
return db.get("face-index", fileID);
};
@@ -214,7 +214,7 @@ export const getFaceIndex = async (fileID: number) => {
/**
* Return all face indexes present locally.
*/
export const getFaceIndexes = async () => {
export const savedFaceIndexes = async () => {
const db = await mlDB();
return await db.getAll("face-index");
};
@@ -222,7 +222,7 @@ export const getFaceIndexes = async () => {
/**
* Return all CLIP indexes present locally.
*/
export const getCLIPIndexes = async () => {
export const savedCLIPIndexes = async () => {
const db = await mlDB();
return await db.getAll("clip-index");
};

View File

@@ -26,7 +26,7 @@ import {
import { deleteUserEntity } from "../user-entity/remote";
import type { FaceCluster } from "./cluster";
import { regenerateFaceCrops } from "./crop";
import { clearMLDB, getFaceIndex, getIndexableAndIndexedCounts } from "./db";
import { clearMLDB, getIndexableAndIndexedCounts, savedFaceIndex } from "./db";
import { filterNamedPeople, reconstructPeople, type Person } from "./people";
import { MLWorker } from "./worker";
import type { CLIPMatches } from "./worker-types";
@@ -633,7 +633,7 @@ export const getAnnotatedFacesForFile = async (
const annotatedFaceIDs: AnnotatedFaceID[] = [];
const otherFaceIDs: string[] = [];
const index = await getFaceIndex(enteFile.id);
const index = await savedFaceIndex(enteFile.id);
if (!index) return { annotatedFaceIDs, otherFaceIDs };
const people = _state.peopleSnapshot ?? [];
@@ -702,7 +702,7 @@ export const faceCrop = async (faceID: string, enteFile: EnteFile) => {
* the file (updating the "face-crops" {@link BlobCache}).
*/
const regenerateFaceCropsIfNeeded = async (enteFile: EnteFile) => {
const index = await getFaceIndex(enteFile.id);
const index = await savedFaceIndex(enteFile.id);
if (!index) return;
const cache = await blobCache("face-crops");

View File

@@ -2,8 +2,9 @@ import type { EnteFile } from "../../types/file";
import { getLocalFiles } from "../files";
import { savedCGroups, type CGroup } from "../user-entity";
import type { FaceCluster } from "./cluster";
import { getFaceIndexes, savedFaceClusters } from "./db";
import { savedFaceClusters, savedFaceIndexes } from "./db";
import { fileIDFromFaceID } from "./face";
import { dotProduct } from "./math";
/**
* A cgroup ("cluster group") is a group of clusters (possibly containing just a
@@ -167,7 +168,7 @@ export const reconstructPeople = async (): Promise<Person[]> => {
{ faceID: string; file: EnteFile; score: number }
>();
const faceIndexes = await getFaceIndexes();
const faceIndexes = await savedFaceIndexes();
for (const { faces } of faceIndexes) {
for (const { faceID, score } of faces) {
const fileID = fileIDFromFaceID(faceID);
@@ -305,3 +306,59 @@ export const filterNamedPeople = (people: Person[]): NamedPerson[] => {
}
return namedPeople;
};
/**
* Returns suggestions for the given person.
*/
export const suggestionsForPerson = async (person: CGroupPerson) => {
const personClusters = person.cgroup.data.assigned;
// TODO-Cluster: Persist this.
const ignoredClusters: FaceCluster[] = [];
const clusters = await savedFaceClusters();
const faceIndexes = await savedFaceIndexes();
const embeddingByFaceID = new Map(
faceIndexes
.map(({ faces }) =>
faces.map(
(f) => [f.faceID, new Float32Array(f.embedding)] as const,
),
)
.flat(),
);
const personClusterIDs = new Set(personClusters.map(({ id }) => id));
const ignoredClusterIDs = new Set(ignoredClusters.map(({ id }) => id));
const personFaceEmbeddings = personClusters
.map(({ faces }) => faces.map((id) => embeddingByFaceID.get(id)))
.flat()
.filter((e) => !!e);
const suggestedClusters: FaceCluster[] = [];
for (const cluster of clusters) {
const { id, faces } = cluster;
if (personClusterIDs.has(id)) continue;
if (ignoredClusterIDs.has(id)) continue;
let suggest = false;
for (const fi of faces) {
const ei = embeddingByFaceID.get(fi);
if (!ei) continue;
for (const ej of personFaceEmbeddings) {
const csim = dotProduct(ei, ej);
if (csim >= 0.6) {
suggest = true;
break;
}
}
if (suggest) break;
}
if (suggest) suggestedClusters.push(cluster);
}
return suggestedClusters.sort((a, b) => b.faces.length - a.faces.length);
};

View File

@@ -32,9 +32,9 @@ import {
} from "./cluster";
import { saveFaceCrops } from "./crop";
import {
getFaceIndexes,
getIndexableFileIDs,
markIndexingFailed,
savedFaceIndexes,
saveIndexes,
updateAssumingLocalFiles,
} from "./db";
@@ -327,7 +327,7 @@ export class MLWorker {
*/
async clusterFaces(masterKey: Uint8Array) {
const clusters = await clusterFaces(
await getFaceIndexes(),
await savedFaceIndexes(),
await getAllLocalFiles(),
(progress) => this.updateClusteringProgress(progress),
);