Outline
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import type { ElectronMLWorker } from "@/base/types/ipc";
|
||||
import type { ImageBitmapAndData } from "./blob";
|
||||
import { getCLIPIndexes } from "./db";
|
||||
import { savedCLIPIndexes } from "./db";
|
||||
import { dotProduct, norm } from "./math";
|
||||
import type { CLIPMatches } from "./worker-types";
|
||||
|
||||
@@ -169,7 +169,7 @@ let _cachedCLIPIndexes:
|
||||
* produces potentially new CLIP indexes).
|
||||
*/
|
||||
const cachedOrReadCLIPIndexes = async () =>
|
||||
(_cachedCLIPIndexes ??= (await getCLIPIndexes()).map(
|
||||
(_cachedCLIPIndexes ??= (await savedCLIPIndexes()).map(
|
||||
({ fileID, embedding }) => ({
|
||||
fileID,
|
||||
embedding: new Float32Array(embedding),
|
||||
|
||||
@@ -206,7 +206,7 @@ const newFileStatus = (fileID: number): FileStatus => ({
|
||||
/**
|
||||
* Return the {@link FaceIndex}, if any, for {@link fileID}.
|
||||
*/
|
||||
export const getFaceIndex = async (fileID: number) => {
|
||||
export const savedFaceIndex = async (fileID: number) => {
|
||||
const db = await mlDB();
|
||||
return db.get("face-index", fileID);
|
||||
};
|
||||
@@ -214,7 +214,7 @@ export const getFaceIndex = async (fileID: number) => {
|
||||
/**
|
||||
* Return all face indexes present locally.
|
||||
*/
|
||||
export const getFaceIndexes = async () => {
|
||||
export const savedFaceIndexes = async () => {
|
||||
const db = await mlDB();
|
||||
return await db.getAll("face-index");
|
||||
};
|
||||
@@ -222,7 +222,7 @@ export const getFaceIndexes = async () => {
|
||||
/**
|
||||
* Return all CLIP indexes present locally.
|
||||
*/
|
||||
export const getCLIPIndexes = async () => {
|
||||
export const savedCLIPIndexes = async () => {
|
||||
const db = await mlDB();
|
||||
return await db.getAll("clip-index");
|
||||
};
|
||||
|
||||
@@ -26,7 +26,7 @@ import {
|
||||
import { deleteUserEntity } from "../user-entity/remote";
|
||||
import type { FaceCluster } from "./cluster";
|
||||
import { regenerateFaceCrops } from "./crop";
|
||||
import { clearMLDB, getFaceIndex, getIndexableAndIndexedCounts } from "./db";
|
||||
import { clearMLDB, getIndexableAndIndexedCounts, savedFaceIndex } from "./db";
|
||||
import { filterNamedPeople, reconstructPeople, type Person } from "./people";
|
||||
import { MLWorker } from "./worker";
|
||||
import type { CLIPMatches } from "./worker-types";
|
||||
@@ -633,7 +633,7 @@ export const getAnnotatedFacesForFile = async (
|
||||
const annotatedFaceIDs: AnnotatedFaceID[] = [];
|
||||
const otherFaceIDs: string[] = [];
|
||||
|
||||
const index = await getFaceIndex(enteFile.id);
|
||||
const index = await savedFaceIndex(enteFile.id);
|
||||
if (!index) return { annotatedFaceIDs, otherFaceIDs };
|
||||
|
||||
const people = _state.peopleSnapshot ?? [];
|
||||
@@ -702,7 +702,7 @@ export const faceCrop = async (faceID: string, enteFile: EnteFile) => {
|
||||
* the file (updating the "face-crops" {@link BlobCache}).
|
||||
*/
|
||||
const regenerateFaceCropsIfNeeded = async (enteFile: EnteFile) => {
|
||||
const index = await getFaceIndex(enteFile.id);
|
||||
const index = await savedFaceIndex(enteFile.id);
|
||||
if (!index) return;
|
||||
|
||||
const cache = await blobCache("face-crops");
|
||||
|
||||
@@ -2,8 +2,9 @@ import type { EnteFile } from "../../types/file";
|
||||
import { getLocalFiles } from "../files";
|
||||
import { savedCGroups, type CGroup } from "../user-entity";
|
||||
import type { FaceCluster } from "./cluster";
|
||||
import { getFaceIndexes, savedFaceClusters } from "./db";
|
||||
import { savedFaceClusters, savedFaceIndexes } from "./db";
|
||||
import { fileIDFromFaceID } from "./face";
|
||||
import { dotProduct } from "./math";
|
||||
|
||||
/**
|
||||
* A cgroup ("cluster group") is a group of clusters (possibly containing just a
|
||||
@@ -167,7 +168,7 @@ export const reconstructPeople = async (): Promise<Person[]> => {
|
||||
{ faceID: string; file: EnteFile; score: number }
|
||||
>();
|
||||
|
||||
const faceIndexes = await getFaceIndexes();
|
||||
const faceIndexes = await savedFaceIndexes();
|
||||
for (const { faces } of faceIndexes) {
|
||||
for (const { faceID, score } of faces) {
|
||||
const fileID = fileIDFromFaceID(faceID);
|
||||
@@ -305,3 +306,59 @@ export const filterNamedPeople = (people: Person[]): NamedPerson[] => {
|
||||
}
|
||||
return namedPeople;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns suggestions for the given person.
|
||||
*/
|
||||
export const suggestionsForPerson = async (person: CGroupPerson) => {
|
||||
const personClusters = person.cgroup.data.assigned;
|
||||
// TODO-Cluster: Persist this.
|
||||
const ignoredClusters: FaceCluster[] = [];
|
||||
|
||||
const clusters = await savedFaceClusters();
|
||||
const faceIndexes = await savedFaceIndexes();
|
||||
|
||||
const embeddingByFaceID = new Map(
|
||||
faceIndexes
|
||||
.map(({ faces }) =>
|
||||
faces.map(
|
||||
(f) => [f.faceID, new Float32Array(f.embedding)] as const,
|
||||
),
|
||||
)
|
||||
.flat(),
|
||||
);
|
||||
|
||||
const personClusterIDs = new Set(personClusters.map(({ id }) => id));
|
||||
const ignoredClusterIDs = new Set(ignoredClusters.map(({ id }) => id));
|
||||
|
||||
const personFaceEmbeddings = personClusters
|
||||
.map(({ faces }) => faces.map((id) => embeddingByFaceID.get(id)))
|
||||
.flat()
|
||||
.filter((e) => !!e);
|
||||
|
||||
const suggestedClusters: FaceCluster[] = [];
|
||||
for (const cluster of clusters) {
|
||||
const { id, faces } = cluster;
|
||||
|
||||
if (personClusterIDs.has(id)) continue;
|
||||
if (ignoredClusterIDs.has(id)) continue;
|
||||
|
||||
let suggest = false;
|
||||
for (const fi of faces) {
|
||||
const ei = embeddingByFaceID.get(fi);
|
||||
if (!ei) continue;
|
||||
for (const ej of personFaceEmbeddings) {
|
||||
const csim = dotProduct(ei, ej);
|
||||
if (csim >= 0.6) {
|
||||
suggest = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (suggest) break;
|
||||
}
|
||||
|
||||
if (suggest) suggestedClusters.push(cluster);
|
||||
}
|
||||
|
||||
return suggestedClusters.sort((a, b) => b.faces.length - a.faces.length);
|
||||
};
|
||||
|
||||
@@ -32,9 +32,9 @@ import {
|
||||
} from "./cluster";
|
||||
import { saveFaceCrops } from "./crop";
|
||||
import {
|
||||
getFaceIndexes,
|
||||
getIndexableFileIDs,
|
||||
markIndexingFailed,
|
||||
savedFaceIndexes,
|
||||
saveIndexes,
|
||||
updateAssumingLocalFiles,
|
||||
} from "./db";
|
||||
@@ -327,7 +327,7 @@ export class MLWorker {
|
||||
*/
|
||||
async clusterFaces(masterKey: Uint8Array) {
|
||||
const clusters = await clusterFaces(
|
||||
await getFaceIndexes(),
|
||||
await savedFaceIndexes(),
|
||||
await getAllLocalFiles(),
|
||||
(progress) => this.updateClusteringProgress(progress),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user