[desktop] Clustering WIP - Part 2/x (#2701)

This commit is contained in:
Manav Rathi
2024-08-14 16:06:17 +05:30
committed by GitHub
12 changed files with 668 additions and 255 deletions

View File

@@ -1,6 +1,6 @@
import { FileType } from "@/media/file-type";
import { isMLEnabled } from "@/new/photos/services/ml";
import type { Person } from "@/new/photos/services/ml/people";
import type { SearchPerson } from "@/new/photos/services/search";
import { EnteFile } from "@/new/photos/types/file";
import CloseIcon from "@mui/icons-material/Close";
import { IconButton } from "@mui/material";
@@ -146,7 +146,7 @@ export default function SearchInput(props: Iprops) {
search = { files: selectedOption.value as number[] };
break;
case SuggestionType.PERSON:
search = { person: selectedOption.value as Person };
search = { person: selectedOption.value as SearchPerson };
break;
case SuggestionType.FILE_TYPE:
search = { fileType: selectedOption.value as FileType };

View File

@@ -9,8 +9,9 @@ import {
wipCluster,
wipClusterEnable,
} from "@/new/photos/services/ml";
import type { Person } from "@/new/photos/services/ml/people";
import { personDiff } from "@/new/photos/services/user-entity";
import { persons } from "@/new/photos/services/ml/db";
import type { SearchPerson } from "@/new/photos/services/search";
import { syncPersons } from "@/new/photos/services/user-entity";
import { EnteFile } from "@/new/photos/types/file";
import * as chrono from "chrono-node";
import { t } from "i18next";
@@ -27,7 +28,7 @@ import {
import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
import { getUniqueFiles } from "utils/file";
import { getFormattedDate } from "utils/search";
import { getEntityKey, getLatestEntities } from "./entityService";
import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService";
const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
@@ -406,7 +407,7 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
return { files: option.value as number[] };
case SuggestionType.PERSON:
return { person: option.value as Person };
return { person: option.value as SearchPerson };
case SuggestionType.FILE_TYPE:
return { fileType: option.value as FileType };
@@ -416,20 +417,19 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
}
}
let done = false;
async function getAllPeople(limit?: number) {
if (!(await wipClusterEnable())) return [];
if (done) return [];
done = true;
if (process.env.NEXT_PUBLIC_ENTE_WIP_CL_FETCH) {
const entityKey = await getEntityKey("person" as EntityType);
const peopleR = await personDiff(entityKey.data);
const r = peopleR.length;
log.debug(() => ["people", peopleR]);
if (r) return [];
return [];
await syncPersons();
const people = await persons();
log.debug(() => ["people", { people }]);
}
let people: Array<Person> = []; // await mlIDbStorage.getAllPeople();
let people: Array<SearchPerson> = []; // await mlIDbStorage.getAllPeople();
people = await wipCluster();
// await mlPeopleStore.iterate<Person, void>((person) => {
// people.push(person);

View File

@@ -1,6 +1,6 @@
import { FileType } from "@/media/file-type";
import type { MLStatus } from "@/new/photos/services/ml";
import type { Person } from "@/new/photos/services/ml/people";
import type { SearchPerson } from "@/new/photos/services/search";
import { EnteFile } from "@/new/photos/types/file";
import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity";
@@ -30,7 +30,7 @@ export interface Suggestion {
value:
| DateValue
| number[]
| Person
| SearchPerson
| MLStatus
| LocationTagData
| City
@@ -45,7 +45,7 @@ export type Search = {
city?: City;
collection?: number;
files?: number[];
person?: Person;
person?: SearchPerson;
fileType?: FileType;
clip?: ClipSearchScores;
};

View File

@@ -0,0 +1,31 @@
import { sharedCryptoWorker } from "@/base/crypto";
import { z } from "zod";
/**
* Return the base64 encoded user's encryption key from session storage.
*
* Precondition: The user should be logged in.
*/
export const usersEncryptionKeyB64 = async () => {
// TODO: Same value as the deprecated SESSION_KEYS.ENCRYPTION_KEY.
const value = sessionStorage.getItem("encryptionKey");
if (!value) {
throw new Error(
"The user's encryption key was not found in session storage. Likely they are not logged in.",
);
}
const { encryptedData, key, nonce } = EncryptionKeyAttributes.parse(
JSON.parse(value),
);
const cryptoWorker = await sharedCryptoWorker();
return cryptoWorker.decryptB64(encryptedData, nonce, key);
};
// TODO: Same as B64EncryptionResult. Revisit.
const EncryptionKeyAttributes = z.object({
encryptedData: z.string(),
key: z.string(),
nonce: z.string(),
});
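// A minimal usage sketch (illustrative, not part of this commit): decrypting
// a blob that was encrypted using the user's master key. The function and
// parameter names here are assumptions; decryptB64 is the same worker method
// used above.
export const decryptWithMasterKey = async (
    encryptedDataB64: string,
    nonceB64: string,
) => {
    const masterKeyB64 = await usersEncryptionKeyB64();
    const cryptoWorker = await sharedCryptoWorker();
    return cryptoWorker.decryptB64(encryptedDataB64, nonceB64, masterKeyB64);
};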

View File

@@ -1,14 +1,14 @@
import { faceCrop, unidentifiedFaceIDs } from "@/new/photos/services/ml";
import type { Person } from "@/new/photos/services/ml/people";
import type { EnteFile } from "@/new/photos/types/file";
import { Skeleton, Typography, styled } from "@mui/material";
import { t } from "i18next";
import React, { useEffect, useState } from "react";
import type { SearchPerson } from "../services/search";
export interface PeopleListProps {
people: Person[];
people: SearchPerson[];
maxRows: number;
onSelect?: (person: Person, index: number) => void;
onSelect?: (person: SearchPerson, index: number) => void;
}
export const PeopleList: React.FC<PeopleListProps> = ({
@@ -60,7 +60,7 @@ const FaceChip = styled("div")<{ clickable?: boolean }>`
export interface PhotoPeopleListProps {
file: EnteFile;
onSelect?: (person: Person, index: number) => void;
onSelect?: (person: SearchPerson, index: number) => void;
}
export function PhotoPeopleList() {

View File

@@ -186,5 +186,8 @@ export const clipMatches = async (
// This code is on the hot path, so these optimizations help.
[fileID, dotProduct(embedding, textEmbedding)] as const,
);
return new Map(items.filter(([, score]) => score >= 0.2));
// This score threshold was obtained heuristically. 0.2 generally gives solid
// results, and around 0.15 we start getting many false positives (all this
// is query dependent too).
return new Map(items.filter(([, score]) => score >= 0.175));
};
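// For reference, each score above is the dot product of an image embedding
// with the text embedding. For L2-normalized embeddings this equals their
// cosine similarity. A minimal sketch of such a helper (the number[] element
// type is an assumption; the real dotProduct is imported):
const dotProductSketch = (a: number[], b: number[]) => {
    let sum = 0;
    for (let i = 0; i < a.length; i++) sum += (a[i] ?? 0) * (b[i] ?? 0);
    return sum;
};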

View File

@@ -1,13 +1,18 @@
import { newNonSecureID } from "@/base/id-worker";
import log from "@/base/log";
import { ensure } from "@/utils/ensure";
import type { FaceIndex } from "./face";
import { faceClusters, persons } from "./db";
import type { Face, FaceIndex } from "./face";
import { dotProduct } from "./math";
/**
* A face cluster is a set of faces.
*
* Each cluster has an id so that a Person (a set of clusters) can refer to it.
* Each cluster has an id so that a {@link Person} can refer to it.
*
* The cluster is not directly synced to remote. But it does indirectly get
* synced if it gets promoted or attached to a person (which can be thought of
* as a named or hidden cluster).
*/
export interface FaceCluster {
/**
@@ -15,42 +20,78 @@ export interface FaceCluster {
*/
id: string;
/**
* An unordered set of ids of the faces that belong to the cluster.
* An unordered set of ids of the faces that belong to this cluster.
*
* For ergonomics of transportation and persistence this is an array but it
* For ergonomics of transportation and persistence this is an array, but it
* should conceptually be thought of as a set.
*/
faceIDs: string[];
}
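// A hypothetical example of a cluster. The id prefix matches the
// newNonSecureID("cluster_") generator used below; the face IDs follow the
// `${fileID}_...` convention that fileIDFromFaceID assumes elsewhere in this
// commit (all values illustrative).
const exampleCluster: FaceCluster = {
    id: "cluster_vg8tjEkJT3zOLMUZrqyKgA",
    faceIDs: ["1234_0", "1234_1", "5678_0"],
};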
/**
* A Person is a set of clusters, with some attached metadata.
* A Person is a set of clusters with some attached metadata.
*
* The person is the user visible concept. It consists of a set of clusters,
* each of which itself is a set of faces.
* More precisely, a person is a single cluster or a set of clusters that the
* user has interacted with.
*
* For ease of transportation, the Person entity on remote looks like
* The most frequent interaction is naming a {@link FaceCluster}, which promotes
* it to become a {@link Person}. The promotion comes with the ability to be
* synced with remote (as a "person_v2" user entity).
*
* Thereafter, the user may attach more clusters to the same {@link Person}.
*
* The other form of interaction is hiding. The user may hide a single (unnamed)
* cluster, or they may hide a person.
*
* The Person entity on remote has clusters embedded within itself
*
* { name, clusters: [{ clusterID, faceIDs }] }
*
* That is, it has the clusters embedded within itself.
* Since clusters don't get independently synced, one way to think about a
* Person is that it is an interaction with a cluster that we want to sync.
*/
export interface Person {
/**
* A nanoid for this person.
* A UUID or nanoid for this person.
*
* This is the ID of the Person user entity; it is not contained as part of
* the Person entity payload.
*/
id: string;
/**
* An optional name assigned by the user to this person.
* A name assigned by the user to this person.
*
* This can be missing or an empty string for an unnamed cluster that was
* hidden.
*/
name: string | undefined;
/**
* An unordered set of ids of the clusters that belong to this person.
*
* For ergonomics of transportation and persistence this is an array but it
* For ergonomics of transportation and persistence this is an array, but it
* should conceptually be thought of as a set.
*/
clusterIDs: string[];
/**
* True if this person should be hidden.
*
* This can also be true for unnamed hidden clusters. When the user hides a
* single cluster that was offered as a suggestion to them on a client, then
* the client will create a new person entity without a name, and set its
* hidden flag to sync it with remote (so that other clients can also stop
* showing this cluster).
*/
isHidden: boolean;
/**
* The ID of the face that should be used as the cover photo for this person
* (if the user has set one).
*/
avatarFaceID: string | undefined;
/**
* Locally determined ID of the "best" face that should be used as the
* display face, to represent this person in the UI.
*/
displayFaceID: string | undefined;
}
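// A hypothetical named person spanning two clusters (all values illustrative):
const examplePerson: Person = {
    id: "Yqk9DPoTRDqS",
    name: "Alice",
    clusterIDs: ["cluster_a", "cluster_b"],
    isHidden: false,
    avatarFaceID: undefined, // the user hasn't picked a cover face
    displayFaceID: "1234_0", // computed locally, not part of the remote payload
};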
/**
@@ -58,34 +99,79 @@ export interface Person {
*
* [Note: Face clustering algorithm]
*
* 1. clusters = []
* A person consists of clusters, each of which itself is a set of faces.
*
* The clusters are generated locally by clients using the following algorithm:
*
* 1. clusters = [] initially, or fetched from remote.
*
* 2. For each face, find its nearest neighbour in the embedding space from
* amongst the faces that have already been clustered.
*
* 3. If no such neighbour is found within our threshold, create a new cluster.
*
* 4. Otherwise assign this face to the same cluster as its nearest neighbour.
*
* [Note: Face clustering feedback]
* The user can then tweak the output of the algorithm by performing the
* following actions on the list of clusters that they can see:
*
* This user can tweak the output of the algorithm by providing feedback. They
* can perform the following actions:
* - They can provide a name for a cluster. This upgrades a cluster into a
* "Person", which then gets synced via remote to all their devices.
*
* 1. Move a cluster from one person to another.
* 2. Break a cluster.
* - They can attach more clusters to a person.
*
* - They can remove a cluster from a person.
*
* After clustering, we also do some routine cleanup. Faces belonging to files
* that have been deleted (including those in Trash) should be pruned off.
*
* We should not make strict assumptions about the clusters we get from remote.
* In particular, the same face ID can be in different clusters. In such cases
* we should arbitrarily assign it to the last cluster we find it in.
* Such leeway is intentionally provided to allow clients some slack in how they
* implement the sync without making a blocking API request for every user
* interaction.
*/
export const clusterFaces = (faceIndexes: FaceIndex[]) => {
export const clusterFaces = async (faceIndexes: FaceIndex[]) => {
const t = Date.now();
const faces = [...faceIDAndEmbeddings(faceIndexes)];
// A flattened array of faces.
const faces = [...enumerateFaces(faceIndexes)];
let clusters: FaceCluster[] = [];
const clusterIndexByFaceID = new Map<string, number>();
// Start with the clusters we already have (either from a previous indexing,
// or fetched from remote).
const clusters = await faceClusters();
// For fast reverse lookup - map from cluster ids to the index in the
// clusters array.
const clusterIndexForClusterID = new Map(clusters.map((c, i) => [c.id, i]));
// For fast reverse lookup - map from face ids to the id of the cluster to
// which they belong.
const clusterIDForFaceID = new Map(
clusters.flatMap((c) =>
c.faceIDs.map((faceID) => [faceID, c.id] as const),
),
);
// New cluster ID generator function.
const newClusterID = () => newNonSecureID("cluster_");
// For each face,
for (const [i, { faceID, embedding }] of faces.entries()) {
// Find the nearest neighbour from among the faces we have already seen.
let nnIndex: number | undefined;
// If the face is already part of a cluster, then skip it.
if (clusterIDForFaceID.get(faceID)) continue;
// Find the nearest neighbour from among all the other faces.
let nn: Face | undefined;
let nnCosineSimilarity = 0;
for (let j = 0; j < i; j++) {
// Can't find a way of avoiding the null assertion.
for (let j = 0; j < faces.length; j++) {
// ! This is an O(n^2) loop, be careful when adding more code here.
// Skip ourselves.
if (i == j) continue;
// Can't find a way of avoiding the null assertion here.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const n = faces[j]!;
@@ -93,50 +179,100 @@ export const clusterFaces = (faceIndexes: FaceIndex[]) => {
// dot product as their cosine similarity.
const csim = dotProduct(embedding, n.embedding);
if (csim > 0.76 && csim > nnCosineSimilarity) {
nnIndex = j;
nn = n;
nnCosineSimilarity = csim;
}
}
if (nnIndex === undefined) {
// We didn't find a neighbour. Create a new cluster with this face.
const cluster = {
id: newNonSecureID("cluster_"),
faceIDs: [faceID],
};
clusters.push(cluster);
clusterIndexByFaceID.set(faceID, clusters.length);
if (nn) {
// Found a neighbour near enough.
// Find the cluster the nearest neighbour belongs to, if any.
const nnClusterID = clusterIDForFaceID.get(nn.faceID);
if (nnClusterID) {
// If the neighbour is already part of a cluster, also add
// ourselves to that cluster.
const nnClusterIndex = ensure(
clusterIndexForClusterID.get(nnClusterID),
);
clusters[nnClusterIndex]?.faceIDs.push(faceID);
clusterIDForFaceID.set(faceID, nnClusterID);
} else {
// Otherwise create a new cluster with us and our nearest
// neighbour.
const cluster = {
id: newClusterID(),
faceIDs: [faceID, nn.faceID],
};
clusters.push(cluster);
clusterIndexForClusterID.set(cluster.id, clusters.length - 1);
clusterIDForFaceID.set(faceID, cluster.id);
clusterIDForFaceID.set(nn.faceID, cluster.id);
}
} else {
// Found a neighbour near enough. Add this face to the neighbour's
// cluster.
// We didn't find a neighbour within the threshold. Create a new
// cluster with only this face.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const nn = faces[nnIndex]!;
const nnClusterIndex = ensure(clusterIndexByFaceID.get(nn.faceID));
clusters[nnClusterIndex]?.faceIDs.push(faceID);
clusterIndexByFaceID.set(faceID, nnClusterIndex);
const cluster = { id: newClusterID(), faceIDs: [faceID] };
clusters.push(cluster);
clusterIndexForClusterID.set(cluster.id, clusters.length - 1);
clusterIDForFaceID.set(faceID, cluster.id);
}
}
clusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1);
// Prune too small clusters.
const validClusters = clusters.filter(({ faceIDs }) => faceIDs.length > 1);
log.debug(() => ["ml/cluster", { faces, clusters, clusterIndexByFaceID }]);
// For each person, use the highest scoring face in any of its clusters as
// its display face.
const faceForFaceID = new Map(faces.map((f) => [f.faceID, f]));
const people = await persons();
for (const person of people) {
person.avatarFaceID = person.clusterIDs
.map((clusterID) => clusterIndexForClusterID.get(clusterID))
.map((clusterIndex) =>
clusterIndex !== undefined ? clusters[clusterIndex] : undefined,
)
.filter((cluster) => !!cluster)
.flatMap((cluster) => cluster.faceIDs)
.map((id) => faceForFaceID.get(id))
.filter((face) => !!face)
.reduce((topFace, face) =>
topFace.score > face.score ? topFace : face,
).faceID;
}
log.debug(() => [
"ml/cluster",
{
faces,
validClusters,
clusterIndexForClusterID,
clusterIDForFaceID,
people,
},
]);
log.debug(
() =>
`Clustered ${faces.length} faces into ${clusters.length} clusters (${Date.now() - t} ms)`,
`Clustered ${faces.length} faces into ${validClusters.length} clusters (${Date.now() - t} ms)`,
);
return clusters;
return { clusters: validClusters, people };
};
/**
* A generator function that returns a stream of the faces present in the
* given {@link faceIndices}, flattening all of them into a single sequence.
*/
function* faceIDAndEmbeddings(faceIndices: FaceIndex[]) {
function* enumerateFaces(faceIndices: FaceIndex[]) {
for (const fi of faceIndices) {
for (const f of fi.faces) {
yield { faceID: f.faceID, embedding: f.embedding };
yield f;
}
}
}
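// For intuition about the 0.76 cosine similarity threshold used above, with
// unit vector "embeddings" (illustrative values only):
//
//     dotProduct([1, 0], [0.8, 0.6]) // 0.8 -> near enough, joins the cluster
//     dotProduct([1, 0], [0.6, 0.8]) // 0.6 -> below threshold, new cluster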

View File

@@ -3,6 +3,7 @@ import log from "@/base/log";
import localForage from "@ente/shared/storage/localForage";
import { deleteDB, openDB, type DBSchema } from "idb";
import type { LocalCLIPIndex } from "./clip";
import type { FaceCluster, Person } from "./cluster-new";
import type { LocalFaceIndex } from "./face";
/**
@@ -43,6 +44,14 @@ interface MLDBSchema extends DBSchema {
key: number;
value: LocalCLIPIndex;
};
"face-cluster": {
key: string;
value: FaceCluster;
};
person: {
key: string;
value: Person;
};
}
interface FileStatus {
@@ -98,6 +107,13 @@ const openMLDB = async () => {
if (oldVersion < 2) {
db.createObjectStore("clip-index", { keyPath: "fileID" });
}
// TODO-Cluster
if (oldVersion < 3) {
if (process.env.NEXT_PUBLIC_ENTE_WIP_CL) {
db.createObjectStore("face-cluster", { keyPath: "id" });
db.createObjectStore("person", { keyPath: "id" });
}
}
},
blocking() {
log.info(
@@ -393,3 +409,92 @@ export const markIndexingFailed = async (fileID: number) => {
fileStatus.failureCount = fileStatus.failureCount + 1;
await Promise.all([tx.store.put(fileStatus), tx.done]);
};
/**
* Return all face clusters present locally.
*/
export const faceClusters = async () => {
const db = await mlDB();
return db.getAll("face-cluster");
};
/**
* Return all person entries (aka "people") present locally.
*/
export const persons = async () => {
const db = await mlDB();
return db.getAll("person");
};
/**
* Replace the face clusters stored locally with the given ones.
*
* This function deletes all entries from the face cluster object store, and then
* inserts the given {@link clusters} into it.
*/
export const setFaceClusters = async (clusters: FaceCluster[]) => {
const db = await mlDB();
const tx = db.transaction("face-cluster", "readwrite");
await tx.store.clear();
await Promise.all(clusters.map((cluster) => tx.store.put(cluster)));
return tx.done;
};
/**
* Update the person store to reflect the given changes, in order.
*
* @param diff A list of changes to apply. Each entry is either
*
* - A string, in which case the person with the given string as their ID
* should be deleted from the store, or
*
* - A person, in which case it should add or overwrite the entry for the
* corresponding person (as identified by their {@link id}).
*/
export const applyPersonDiff = async (diff: (string | Person)[]) => {
const db = await mlDB();
const tx = db.transaction("person", "readwrite");
// See: [Note: Diff response will have at most one entry for an id]
await Promise.all(
diff.map((d) =>
typeof d == "string" ? tx.store.delete(d) : tx.store.put(d),
),
);
return tx.done;
};
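// A hypothetical invocation, mixing an upsert and a deletion (IDs and field
// values are illustrative):
//
// await applyPersonDiff([
//     {
//         id: "person-1",
//         name: "Alice",
//         clusterIDs: ["cluster_a"],
//         isHidden: false,
//         avatarFaceID: undefined,
//         displayFaceID: undefined,
//     },
//     "person-2", // a bare string means: delete the person with this id
// ]);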
/**
* Add or overwrite the entry for the given {@link person}, as identified by
* their {@link id}.
*/
// TODO-Cluster: Remove me
export const savePerson = async (person: Person) => {
const db = await mlDB();
const tx = db.transaction("person", "readwrite");
await Promise.all([tx.store.put(person), tx.done]);
};
/**
* Delete the entry for the person with the given {@link id}, if any.
*/
// TODO-Cluster: Remove me
export const deletePerson = async (id: string) => {
const db = await mlDB();
const tx = db.transaction("person", "readwrite");
await Promise.all([tx.store.delete(id), tx.done]);
};
/**
* Replace the persons stored locally with the given ones.
*
* This function deletes all entries from the person object store, and then
* inserts the given {@link persons} into it.
*/
// TODO-Cluster: Remove me
export const setPersons = async (persons: Person[]) => {
const db = await mlDB();
const tx = db.transaction("person", "readwrite");
await tx.store.clear();
await Promise.all(persons.map((person) => tx.store.put(person)));
return tx.done;
};

View File

@@ -18,6 +18,7 @@ import { proxy, transfer } from "comlink";
import { isInternalUser } from "../feature-flags";
import { getAllLocalFiles } from "../files";
import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
import type { SearchPerson } from "../search";
import type { UploadItem } from "../upload/types";
import { clusterFaces } from "./cluster-new";
import { regenerateFaceCrops } from "./crop";
@@ -27,7 +28,6 @@ import {
faceIndexes,
indexableAndIndexedCounts,
} from "./db";
import type { Person } from "./people";
import { MLWorker } from "./worker";
import type { CLIPMatches } from "./worker-types";
@@ -314,7 +314,8 @@ export const indexNewUpload = (enteFile: EnteFile, uploadItem: UploadItem) => {
void worker().then((w) => w.onUpload(enteFile, uploadItem));
};
let last: Person[] | undefined;
// TODO-Cluster temporary import here
let last: SearchPerson[] | undefined;
/**
* WIP! Don't enable, dragon eggs are hatching here.
@@ -330,32 +331,48 @@ export const wipCluster = async () => {
if (last) return last;
const clusters = clusterFaces(await faceIndexes());
const { clusters, people } = await clusterFaces(await faceIndexes());
const clusterByID = new Map(
clusters.map((cluster) => [cluster.id, cluster]),
);
const localFiles = await getAllLocalFiles();
const localFilesByID = new Map(localFiles.map((f) => [f.id, f]));
const people: Person[] = []; // await mlIDbStorage.getAllPeople();
for (const cluster of clusters) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const dfID = cluster.faceIDs[0]!;
const dfFile = localFilesByID.get(fileIDFromFaceID(dfID) ?? 0);
if (!dfFile) {
assertionFailed(`Face ID ${dfID} without local file`);
const result: SearchPerson[] = [];
for (const person of people) {
let avatarFaceID = person.avatarFaceID;
// TODO-Cluster
// Temp
if (!avatarFaceID) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
avatarFaceID = person.clusterIDs
.map((id) => clusterByID.get(id))
.flatMap((cluster) => cluster?.faceIDs ?? [])[0]!;
}
const avatarFaceFileID = fileIDFromFaceID(avatarFaceID);
const avatarFaceFile = localFilesByID.get(avatarFaceFileID ?? 0);
if (!avatarFaceFileID || !avatarFaceFile) {
assertionFailed(`Face ID ${avatarFaceID} without local file`);
continue;
}
people.push({
id: Math.random(), //cluster.id,
name: "test",
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
files: cluster.faceIDs.map((s) => parseInt(s.split("_")[0]!)),
displayFaceID: dfID,
displayFaceFile: dfFile,
const files = person.clusterIDs
.map((id) => clusterByID.get(id))
.flatMap((cluster) => cluster?.faceIDs ?? [])
.map((faceID) => fileIDFromFaceID(faceID))
.filter((fileID) => fileID !== undefined);
result.push({
id: person.id,
name: person.name,
files,
displayFaceID: avatarFaceID,
displayFaceFile: avatarFaceFile,
});
}
last = people;
return people;
last = result;
return result;
};
export type MLStatus =
@@ -510,7 +527,8 @@ export const unidentifiedFaceIDs = async (
};
/**
* Extract the ID of the {@link EnteFile} to which a face belongs from its ID.
* Extract the fileID of the {@link EnteFile} to which the face belongs from its
* faceID.
*/
const fileIDFromFaceID = (faceID: string) => {
const fileID = parseInt(faceID.split("_")[0] ?? "");

View File

@@ -1,136 +0,0 @@
import type { EnteFile } from "../../types/file";
export interface Person {
id: number;
name?: string;
files: number[];
displayFaceID: string;
displayFaceFile: EnteFile;
}
// Forced disable clustering. It doesn't currently work.
//
// > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed
/*
export const syncPeopleIndex = async () => {
if (
syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === batchSize && Math.random() < 0)
) {
await this.syncIndex(syncContext);
}
public async syncIndex(syncContext: MLSyncContext) {
await this.getMLLibraryData(syncContext);
await syncPeopleIndex(syncContext);
await this.persistMLLibraryData(syncContext);
}
const filesVersion = await mlIDbStorage.getIndexVersion("files");
if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) {
return;
}
// TODO: have faces addresable through fileId + faceId
// to avoid index based addressing, which is prone to wrong results
// one way could be to match nearest face within threshold in the file
const allFacesMap =
syncContext.allSyncedFacesMap ??
(syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap());
// await this.init();
const allFacesMap = await mlIDbStorage.getAllFacesMap();
const allFaces = [...allFacesMap.values()].flat();
if (!allFaces || allFaces.length < 50) {
log.info(
`Skipping clustering since number of faces (${allFaces.length}) is less than the clustering threshold (50)`,
);
return;
}
log.info("Running clustering allFaces: ", allFaces.length);
const faceClusteringResults = await clusterFaces(
allFaces.map((f) => Array.from(f.embedding)),
);
log.info(
"[MLService] Got face clustering results: ",
JSON.stringify(faceClusteringResults),
);
const clusters = faceClusteringResults?.clusters;
if (!clusters || clusters.length < 1) {
return;
}
for (const face of allFaces) {
face.personId = undefined;
}
await mlIDbStorage.clearAllPeople();
for (const [index, cluster] of clusters.entries()) {
const faces = cluster.map((f) => allFaces[f]).filter((f) => f);
// TODO: take default display face from last leaves of hdbscan clusters
const personFace = faces.reduce((best, face) =>
face.detection.probability > best.detection.probability
? face
: best,
);
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
if (personFace && !personFace.crop?.cacheKey) {
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
await saveFaceCrop(imageBitmap, personFace);
}
const person: Person = {
id: index,
files: faces.map((f) => f.fileId),
displayFaceId: personFace?.id,
};
await mlIDbStorage.putPerson(person);
faces.forEach((face) => {
face.personId = person.id;
});
// log.info("Creating person: ", person, faces);
}
await mlIDbStorage.updateFaces(allFacesMap);
// await mlIDbStorage.setIndexVersion("people", filesVersion);
};
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};
*/

View File

@@ -0,0 +1,13 @@
import type { EnteFile } from "@/new/photos/types/file";
/**
* A massaged version of {@link Person} suitable for being shown in search
* results.
*/
export interface SearchPerson {
id: string;
name?: string;
files: number[];
displayFaceID: string;
displayFaceFile: EnteFile;
}

View File

@@ -1,7 +1,14 @@
import { sharedCryptoWorker } from "@/base/crypto";
import { decryptAssociatedB64Data } from "@/base/crypto/ente";
import { authenticatedRequestHeaders, ensureOk } from "@/base/http";
import { authenticatedRequestHeaders, ensureOk, HTTPError } from "@/base/http";
import { getKV, getKVN, setKV } from "@/base/kv";
import { apiURL } from "@/base/origins";
import { usersEncryptionKeyB64 } from "@/base/session-store";
import { nullToUndefined } from "@/utils/transform";
import { z } from "zod";
import { gunzip } from "./gzip";
import type { Person } from "./ml/cluster-new";
import { applyPersonDiff } from "./ml/db";
/**
* User entities are predefined lists of otherwise arbitrary data that the user
@@ -10,12 +17,11 @@ import { z } from "zod";
* e.g. location tags, people in their photos.
*/
export type EntityType =
| "person"
/**
* A new version of the Person entity where the data is gzipped before
* encryption.
* The latest iteration of the Person entity format, where the data is
* gzipped before encryption.
*/
| "person_v2";
"person_v2";
/**
* The maximum number of items to fetch in a single diff
@@ -51,6 +57,10 @@ const defaultDiffLimit = 500;
* expected to be associated with this entity type.
*/
interface UserEntity {
/**
* A UUID or nanoid for the entity.
*/
id: string;
/**
* Arbitrary data associated with the entity. The format of this data is
* specific to each entity type.
@@ -64,18 +74,31 @@ interface UserEntity {
updatedAt: number;
}
/** Zod schema for {@link RemoteUserEntity} */
const RemoteUserEntity = z.object({
/** Base64 string containing the encrypted contents of the entity. */
encryptedData: z.string(),
/** Base64 string containing the decryption header. */
header: z.string(),
id: z.string(),
/**
* Base64 string containing the encrypted contents of the entity.
*
* Will be `null` when isDeleted is true.
*/
encryptedData: z.string().nullable(),
/**
* Base64 string containing the decryption header.
*
* Will be `null` when isDeleted is true.
*/
header: z.string().nullable(),
isDeleted: z.boolean(),
updatedAt: z.number(),
});
/** An item in the user entity diff response we get from remote. */
type RemoteUserEntity = z.infer<typeof RemoteUserEntity>;
/**
* Fetch all user entities of the given type that have been created or updated
* since the given time.
* Fetch the next batch of user entities of the given type that have been
* created or updated since the given time.
*
* @param type The type of the entities to fetch.
*
@@ -85,12 +108,45 @@ const RemoteUserEntity = z.object({
*
* @param entityKeyB64 The base64 encoded key to use for decrypting the
* encrypted contents of the user entity.
*
* [Note: Diff response will have at most one entry for an id]
*
* Unlike git diffs which track all changes, the diffs we get from remote are
* guaranteed to contain at most one entry (upsert or delete) for a particular
* Ente object. This holds true irrespective of the diff limit.
*
* For example, in the user entity diff response, it is guaranteed that there
* will be at most one entry for a particular entity id. An entry without data
* indicates that the corresponding entity was deleted. Otherwise, when the
* data is present, it is taken as the creation of a new entity or the
* updation of an existing one.
*
* This behaviour comes from how remote stores the underlying entities. A diff
* returns just the entities whose updation times are greater than the
* provided since time (limited to the given diff limit), so there will be at
* most one row for a particular entity id. And if that entity has been
* deleted, then that row will be a tombstone, so its data will not be present.
*/
export const userEntityDiff = async (
type: EntityType,
sinceTime: number,
entityKeyB64: string,
): Promise<UserEntity[]> => {
const parse = async ({
id,
encryptedData,
header,
isDeleted,
updatedAt,
}: RemoteUserEntity) => ({
id,
data:
encryptedData && header && !isDeleted
? await decrypt(encryptedData, header)
: undefined,
updatedAt,
});
const decrypt = (encryptedDataB64: string, decryptionHeaderB64: string) =>
decryptAssociatedB64Data({
encryptedDataB64,
@@ -111,26 +167,213 @@ export const userEntityDiff = async (
const entities = z
.object({ diff: z.array(RemoteUserEntity) })
.parse(await res.json()).diff;
return Promise.all(
entities.map(
async ({ encryptedData, header, isDeleted, updatedAt }) => ({
data: isDeleted
? undefined
: await decrypt(encryptedData, header),
updatedAt,
}),
),
);
return Promise.all(entities.map(parse));
};
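// Per the note above, a parsed diff will have at most one entry per entity
// id. A hypothetical result (values illustrative):
//
// [
//     { id: "e1", data: Uint8Array [...], updatedAt: 1723620000000 },
//     { id: "e2", data: undefined, updatedAt: 1723620005000 }, // tombstone
// ]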
/**
* Fetch all Person entities that have been created or updated since the last
* time we checked.
* Return the entity key that can be used to decrypt the encrypted contents of
* user entities of the given {@link type}.
*
* 1. We'll see if we have the (encrypted) entity key present locally. If so,
* we'll decrypt it using the user's master key and return it.
*
* 2. Otherwise we'll fetch the entity key for that type from remote. If found,
* we'll decrypt it using the user's master key and return it, also saving
* it locally for future use.
*
* 3. Otherwise we'll create a new one, save it locally and put it to remote.
*
* See also, [Note: User entity keys].
*/
export const personDiff = async (entityKeyB64: string) => {
const entities = await userEntityDiff("person", 0, entityKeyB64);
return entities.map(({ data }) => {
if (!data) return undefined;
return JSON.parse(new TextDecoder().decode(data)) as unknown;
});
const getOrCreateEntityKeyB64 = async (type: EntityType) => {
const encryptionKeyB64 = await usersEncryptionKeyB64();
const worker = await sharedCryptoWorker();
const decrypt = async ({ encryptedKey, header }: RemoteUserEntityKey) => {
return worker.decryptB64(encryptedKey, header, encryptionKeyB64);
};
// See if we already have it locally.
const saved = await savedRemoteUserEntityKey(type);
if (saved) return decrypt(saved);
// See if remote already has it.
const existing = await getUserEntityKey(type);
if (existing) {
// Only save it if we can decrypt it to avoid corrupting our local state
// in unforeseen circumstances.
const result = await decrypt(existing);
await saveRemoteUserEntityKey(type, existing);
return result;
}
// Nada. Create a new one, put it to remote, save it locally, and return.
// TODO-Cluster Keep this read only, only add the writeable bits after other
// stuff has been tested.
throw new Error("Not implemented");
// const generatedKeyB64 = await worker.generateEncryptionKey();
// const encryptedNewKey = await worker.encryptToB64(
// generatedKeyB64,
// encryptionKeyB64,
// );
// await postUserEntityKey(type, encryptedNewKey);
// return decrypt(encryptedNewKey);
};
const entityKeyKey = (type: EntityType) => `entityKey/${type}`;
/**
* Return the locally persisted {@link RemoteUserEntityKey}, if any,
* corresponding the given {@link type}.
*/
const savedRemoteUserEntityKey = (
type: EntityType,
): Promise<RemoteUserEntityKey | undefined> =>
getKV(entityKeyKey(type)).then((s) =>
s ? RemoteUserEntityKey.parse(JSON.parse(s)) : undefined,
);
/**
* Setter for {@link savedRemoteUserEntityKey}.
*/
const saveRemoteUserEntityKey = (
type: EntityType,
entityKey: RemoteUserEntityKey,
) => setKV(entityKeyKey(type), JSON.stringify(entityKey));
/**
* Fetch the encryption key for the given user entity {@link type} from remote.
*
* [Note: User entity keys]
*
* There is one encryption key (itself encrypted with the user's encryption key)
* for each user entity type. If the key doesn't exist on remote, then the
* client is expected to create one on the user's behalf. Remote will disallow
* attempts to create multiple keys for the same user entity type.
*/
const getUserEntityKey = async (
type: EntityType,
): Promise<RemoteUserEntityKey | undefined> => {
const params = new URLSearchParams({ type });
const url = await apiURL("/user-entity/key");
const res = await fetch(`${url}?${params.toString()}`, {
headers: await authenticatedRequestHeaders(),
});
if (!res.ok) {
// Remote says HTTP 404 Not Found if there is no key yet for the user.
if (res.status == 404) return undefined;
throw new HTTPError(res);
} else {
return RemoteUserEntityKey.parse(await res.json());
}
};
const RemoteUserEntityKey = z.object({
encryptedKey: z.string(),
header: z.string(),
});
type RemoteUserEntityKey = z.infer<typeof RemoteUserEntityKey>;
/**
* Create a new encryption key for the given user entity {@link type} on remote.
*
* See: [Note: User entity keys]
*/
// TODO-Cluster remove export
export const postUserEntityKey = async (
type: EntityType,
entityKey: RemoteUserEntityKey,
) => {
const url = await apiURL("/user-entity/key");
const res = await fetch(url, {
method: "POST",
headers: await authenticatedRequestHeaders(),
body: JSON.stringify({ type, ...entityKey }),
});
ensureOk(res);
};
const latestUpdatedAtKey = (type: EntityType) => `latestUpdatedAt/${type}`;
/**
* Return the locally persisted value for the latest `updatedAt` time for the
* given entity {@link type}.
*
* This is used to checkpoint diffs, so that we can resume fetching from the
* last time we did a fetch.
*/
const savedLatestUpdatedAt = (type: EntityType) =>
getKVN(latestUpdatedAtKey(type));
/**
* Setter for {@link savedLatestUpdatedAt}.
*/
const saveLatestUpdatedAt = (type: EntityType, value: number) =>
setKV(latestUpdatedAtKey(type), value);
/**
* Sync the {@link Person} entities that we have locally with remote.
*
* This fetches all the user entities corresponding to the "person_v2" entity
* type from remote that have been created, updated or deleted since the last
* time we checked.
*
* This diff is then applied to the data we have persisted locally.
*/
export const syncPersons = async () => {
const type: EntityType = "person_v2";
const entityKeyB64 = await getOrCreateEntityKeyB64(type);
const parse = async (id: string, data: Uint8Array): Promise<Person> => {
const rp = RemotePerson.parse(JSON.parse(await gunzip(data)));
return {
id,
name: rp.name,
clusterIDs: rp.assigned.map(({ id }) => id),
isHidden: rp.isHidden,
avatarFaceID: rp.avatarFaceID,
displayFaceID: undefined,
};
};
let sinceTime = (await savedLatestUpdatedAt(type)) ?? 0;
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition, no-constant-condition
while (true) {
const entities = await userEntityDiff(type, sinceTime, entityKeyB64);
if (entities.length == 0) break;
await applyPersonDiff(
await Promise.all(
entities.map(async ({ id, data }) =>
data ? await parse(id, data) : id,
),
),
);
sinceTime = entities.reduce(
(max, entity) => Math.max(max, entity.updatedAt),
sinceTime,
);
await saveLatestUpdatedAt(type, sinceTime);
}
};
/** Zod schema for the {@link RemotePerson} type. */
const RemotePerson = z.object({
name: z.string().nullish().transform(nullToUndefined),
assigned: z.array(
z.object({
id: z.string(),
faces: z.string().array(),
}),
),
isHidden: z.boolean(),
avatarFaceID: z.string().nullish().transform(nullToUndefined),
});
/**
* A "person_v2" entity as synced via remote.
*/
type RemotePerson = z.infer<typeof RemotePerson>;
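// A hypothetical "person_v2" payload, after decryption and gunzipping, that
// would parse against the above schema (all values illustrative):
const exampleRemotePerson = RemotePerson.parse({
    name: "Alice",
    assigned: [{ id: "cluster_a", faces: ["1234_0", "5678_1"] }],
    isHidden: false,
    avatarFaceID: "1234_0",
});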