[web] Switch to new storage layer for face indexes (#1931)

Manav Rathi
2024-05-30 19:58:46 +05:30
committed by GitHub
21 changed files with 664 additions and 1036 deletions

View File

@@ -11,7 +11,7 @@ import { Box, DialogProps, Link, Stack, styled } from "@mui/material";
import { Chip } from "components/Chip";
import { EnteDrawer } from "components/EnteDrawer";
import Titlebar from "components/Titlebar";
import { PhotoPeopleList, UnidentifiedFaces } from "components/ml/PeopleList";
import { UnidentifiedFaces } from "components/ml/PeopleList";
import LinkButton from "components/pages/gallery/LinkButton";
import { t } from "i18next";
import { AppContext } from "pages/_app";
@@ -96,8 +96,6 @@ export function FileInfo({
const [parsedExifData, setParsedExifData] = useState<Record<string, any>>();
const [showExif, setShowExif] = useState(false);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [updateMLDataIndex, setUpdateMLDataIndex] = useState(0);
const openExif = () => setShowExif(true);
const closeExif = () => setShowExif(false);
@@ -332,14 +330,8 @@ export function FileInfo({
{appContext.mlSearchEnabled && (
<>
<PhotoPeopleList
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
<UnidentifiedFaces
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
{/* <PhotoPeopleList file={file} /> */}
<UnidentifiedFaces file={file} />
</>
)}
</Stack>

View File

@@ -5,7 +5,6 @@ import { t } from "i18next";
import { AppContext } from "pages/_app";
import { useContext } from "react";
import { components } from "react-select";
import { IndexStatus } from "services/face/db-old";
import { Suggestion, SuggestionType } from "types/search";
const { Menu } = components;
@@ -35,7 +34,7 @@ const MenuWithPeople = (props) => {
(o) => o.type === SuggestionType.INDEX_STATUS,
)[0] as Suggestion;
const indexStatus = indexStatusSuggestion?.value as IndexStatus;
const indexStatus = indexStatusSuggestion?.value;
return (
<Menu {...props}>
<Box my={1}>

View File

@@ -1,10 +1,9 @@
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { Skeleton, styled } from "@mui/material";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import React, { useEffect, useState } from "react";
import mlIDbStorage from "services/face/db-old";
import { unidentifiedFaceIDs } from "services/face/indexer";
import type { Person } from "services/face/people";
import { EnteFile } from "types/file";
@@ -67,63 +66,29 @@ export const PeopleList = React.memo((props: PeopleListProps) => {
export interface PhotoPeopleListProps extends PeopleListPropsBase {
file: EnteFile;
updateMLDataIndex: number;
}
export function PhotoPeopleList(props: PhotoPeopleListProps) {
const [people, setPeople] = useState<Array<Person>>([]);
export function PhotoPeopleList() {
return <></>;
}
export function UnidentifiedFaces({ file }: { file: EnteFile }) {
const [faceIDs, setFaceIDs] = useState<string[]>([]);
useEffect(() => {
let didCancel = false;
async function updateFaceImages() {
log.info("calling getPeopleList");
const startTime = Date.now();
const people = await getPeopleList(props.file);
log.info(`getPeopleList ${Date.now() - startTime} ms`);
log.info(`getPeopleList done, didCancel: ${didCancel}`);
!didCancel && setPeople(people);
}
updateFaceImages();
(async () => {
const faceIDs = await unidentifiedFaceIDs(file);
!didCancel && setFaceIDs(faceIDs);
})();
return () => {
didCancel = true;
};
}, [props.file, props.updateMLDataIndex]);
}, [file]);
if (people.length === 0) return <></>;
return (
<div>
<Legend>{t("PEOPLE")}</Legend>
<PeopleList people={people} onSelect={props.onSelect}></PeopleList>
</div>
);
}
export function UnidentifiedFaces(props: {
file: EnteFile;
updateMLDataIndex: number;
}) {
const [faces, setFaces] = useState<{ id: string }[]>([]);
useEffect(() => {
let didCancel = false;
async function updateFaceImages() {
const faces = await getUnidentifiedFaces(props.file);
!didCancel && setFaces(faces);
}
updateFaceImages();
return () => {
didCancel = true;
};
}, [props.file, props.updateMLDataIndex]);
if (!faces || faces.length === 0) return <></>;
if (faceIDs.length == 0) return <></>;
return (
<>
@@ -131,12 +96,11 @@ export function UnidentifiedFaces(props: {
<Legend>{t("UNIDENTIFIED_FACES")}</Legend>
</div>
<FaceChipContainer>
{faces &&
faces.map((face, index) => (
<FaceChip key={index}>
<FaceCropImageView faceID={face.id} />
</FaceChip>
))}
{faceIDs.map((faceID) => (
<FaceChip key={faceID}>
<FaceCropImageView {...{ faceID }} />
</FaceChip>
))}
</FaceChipContainer>
</>
);
@@ -179,45 +143,3 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({ faceID }) => {
<Skeleton variant="circular" height={120} width={120} />
);
};
async function getPeopleList(file: EnteFile): Promise<Person[]> {
let startTime = Date.now();
const mlFileData = await mlIDbStorage.getFile(file.id);
log.info(
"getPeopleList:mlFilesStore:getItem",
Date.now() - startTime,
"ms",
);
if (!mlFileData?.faces || mlFileData.faces.length < 1) {
return [];
}
const peopleIds = mlFileData.faces
.filter((f) => f.personId !== null && f.personId !== undefined)
.map((f) => f.personId);
if (!peopleIds || peopleIds.length < 1) {
return [];
}
// log.info("peopleIds: ", peopleIds);
startTime = Date.now();
const peoplePromises = peopleIds.map(
(p) => mlIDbStorage.getPerson(p) as Promise<Person>,
);
const peopleList = await Promise.all(peoplePromises);
log.info(
"getPeopleList:mlPeopleStore:getItems",
Date.now() - startTime,
"ms",
);
// log.info("peopleList: ", peopleList);
return peopleList;
}
async function getUnidentifiedFaces(file: EnteFile): Promise<{ id: string }[]> {
const mlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined,
);
}

View File

@@ -50,11 +50,11 @@ import { createContext, useContext, useEffect, useRef, useState } from "react";
import LoadingBar from "react-top-loading-bar";
import DownloadManager from "services/download";
import { resumeExportsIfNeeded } from "services/export";
import { photosLogout } from "services/logout";
import {
getMLSearchConfig,
updateMLSearchConfig,
} from "services/machineLearning/machineLearningService";
isFaceIndexingEnabled,
setIsFaceIndexingEnabled,
} from "services/face/indexer";
import { photosLogout } from "services/logout";
import mlWorkManager from "services/machineLearning/mlWorkManager";
import {
getFamilyPortalRedirectURL,
@@ -186,9 +186,9 @@ export default function App({ Component, pageProps }: AppProps) {
}
const loadMlSearchState = async () => {
try {
const mlSearchConfig = await getMLSearchConfig();
setMlSearchEnabled(mlSearchConfig.enabled);
mlWorkManager.setMlSearchEnabled(mlSearchConfig.enabled);
const enabled = await isFaceIndexingEnabled();
setMlSearchEnabled(enabled);
mlWorkManager.setMlSearchEnabled(enabled);
} catch (e) {
log.error("Error while loading mlSearchEnabled", e);
}
@@ -286,9 +286,7 @@ export default function App({ Component, pageProps }: AppProps) {
const showNavBar = (show: boolean) => setShowNavBar(show);
const updateMlSearchEnabled = async (enabled: boolean) => {
try {
const mlSearchConfig = await getMLSearchConfig();
mlSearchConfig.enabled = enabled;
await updateMLSearchConfig(mlSearchConfig);
await setIsFaceIndexingEnabled(enabled);
setMlSearchEnabled(enabled);
mlWorkManager.setMlSearchEnabled(enabled);
} catch (e) {

View File

@@ -85,10 +85,7 @@ import {
getSectionSummaries,
} from "services/collectionService";
import downloadManager from "services/download";
import {
syncCLIPEmbeddings,
syncFaceEmbeddings,
} from "services/embeddingService";
import { syncCLIPEmbeddings } from "services/embeddingService";
import { syncEntities } from "services/entityService";
import { getLocalFiles, syncFiles } from "services/fileService";
import locationSearchService from "services/locationSearchService";
@@ -130,7 +127,6 @@ import {
} from "utils/file";
import { isArchivedFile } from "utils/magicMetadata";
import { getSessionExpiredMessage } from "utils/ui";
import { isInternalUserForML } from "utils/user";
import { getLocalFamilyData } from "utils/user/family";
export const DeadCenter = styled("div")`
@@ -720,7 +716,9 @@ export default function Gallery() {
const electron = globalThis.electron;
if (electron) {
await syncCLIPEmbeddings();
if (isInternalUserForML()) await syncFaceEmbeddings();
// TODO-ML(MR): Disable fetch until we start storing it in the
// same place as the local ones.
// if (isInternalUserForML()) await syncFaceEmbeddings();
}
if (clipService.isPlatformSupported()) {
void clipService.scheduleImageEmbeddingExtraction();

View File

@@ -7,7 +7,6 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { FileML } from "services/face/remote";
import type {
Embedding,
EmbeddingModel,
@@ -17,9 +16,13 @@ import type {
} from "types/embedding";
import { EnteFile } from "types/file";
import { getLocalCollections } from "./collectionService";
import type { FaceIndex } from "./face/types";
import { getAllLocalFiles } from "./fileService";
import { getLocalTrashedFiles } from "./trashService";
type FileML = FaceIndex & {
updatedAt: number;
};
const DIFF_LIMIT = 500;
/** Local storage key suffix for embedding sync times */

View File

@@ -1,14 +1,18 @@
import { blobCache } from "@/next/blob-cache";
import type { FaceAlignment } from "./f-index";
import type { Box } from "./types";
import type { Face, FaceAlignment } from "./types-old";
export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => {
const faceCrop = extractFaceCrop(imageBitmap, face.alignment);
export const saveFaceCrop = async (
imageBitmap: ImageBitmap,
faceID: string,
alignment: FaceAlignment,
) => {
const faceCrop = extractFaceCrop(imageBitmap, alignment);
const blob = await imageBitmapToBlob(faceCrop);
faceCrop.close();
const cache = await blobCache("face-crops");
await cache.put(face.id, blob);
await cache.put(faceID, blob);
return blob;
};
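The crops saved here are keyed by face ID and read back later for display (e.g. by FaceCropImageView). A minimal retrieval sketch, assuming the blob cache exposes a get counterpart to the put used above:
const cachedFaceCrop = async (faceID: string) => {
    const cache = await blobCache("face-crops");
    // Resolves to the cached crop blob, or undefined if no crop has been
    // saved for this face yet (assumed `get` API, mirroring `put`).
    return await cache.get(faceID);
};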

View File

@@ -1,387 +0,0 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import {
DBSchema,
IDBPDatabase,
IDBPTransaction,
StoreNames,
deleteDB,
openDB,
} from "idb";
import isElectron from "is-electron";
import type { Person } from "services/face/people";
import type { MlFileData } from "services/face/types-old";
import {
DEFAULT_ML_SEARCH_CONFIG,
MAX_ML_SYNC_ERROR_COUNT,
} from "services/machineLearning/machineLearningService";
export interface IndexStatus {
outOfSyncFilesExists: boolean;
nSyncedFiles: number;
nTotalFiles: number;
localFilesSynced: boolean;
peopleIndexSynced: boolean;
}
/**
* TODO(MR): Transient type with the intersection of fields that both the
* existing and new types will have during the migration. Eventually we'll
* store the server ML data shape here exactly.
*/
export interface MinimalPersistedFileData {
fileId: number;
mlVersion: number;
errorCount: number;
faces?: { personId?: number; id: string }[];
}
interface Config {}
export const ML_SEARCH_CONFIG_NAME = "ml-search";
const MLDATA_DB_NAME = "mldata";
interface MLDb extends DBSchema {
files: {
key: number;
value: MinimalPersistedFileData;
indexes: { mlVersion: [number, number] };
};
people: {
key: number;
value: Person;
};
// Unused; we only retain this in the schema so that we can delete it during
// migration.
things: {
key: number;
value: unknown;
};
versions: {
key: string;
value: number;
};
library: {
key: string;
value: unknown;
};
configs: {
key: string;
value: Config;
};
}
class MLIDbStorage {
public _db: Promise<IDBPDatabase<MLDb>>;
constructor() {
if (!haveWindow() || !isElectron()) {
return;
}
this.db;
}
private openDB(): Promise<IDBPDatabase<MLDb>> {
return openDB<MLDb>(MLDATA_DB_NAME, 4, {
terminated: async () => {
log.error("ML Indexed DB terminated");
this._db = undefined;
// TODO: remove if there is a chance of this going into recursion in some case
await this.db;
},
blocked() {
// TODO: make sure we don't allow multiple tabs of the app
log.error("ML Indexed DB blocked");
},
blocking() {
// TODO: make sure we don't allow multiple tabs of the app
log.error("ML Indexed DB blocking");
},
async upgrade(db, oldVersion, newVersion, tx) {
let wasMLSearchEnabled = false;
try {
const searchConfig: unknown = await tx
.objectStore("configs")
.get(ML_SEARCH_CONFIG_NAME);
if (
searchConfig &&
typeof searchConfig == "object" &&
"enabled" in searchConfig &&
typeof searchConfig.enabled == "boolean"
) {
wasMLSearchEnabled = searchConfig.enabled;
}
} catch (e) {
// The configs store might not exist (e.g. during logout).
// Ignore.
}
log.info(
`Previous ML database v${oldVersion} had ML search ${wasMLSearchEnabled ? "enabled" : "disabled"}`,
);
if (oldVersion < 1) {
const filesStore = db.createObjectStore("files", {
keyPath: "fileId",
});
filesStore.createIndex("mlVersion", [
"mlVersion",
"errorCount",
]);
db.createObjectStore("people", {
keyPath: "id",
});
db.createObjectStore("things", {
keyPath: "id",
});
db.createObjectStore("versions");
db.createObjectStore("library");
}
if (oldVersion < 2) {
// TODO: update configs if version is updated in defaults
db.createObjectStore("configs");
/*
await tx
.objectStore("configs")
.add(
DEFAULT_ML_SYNC_JOB_CONFIG,
"ml-sync-job",
);
await tx
.objectStore("configs")
.add(DEFAULT_ML_SYNC_CONFIG, ML_SYNC_CONFIG_NAME);
*/
}
if (oldVersion < 3) {
await tx
.objectStore("configs")
.add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
}
/*
This'll go in version 5. Note that version 4 was never released,
but it was in main for a while, so we'll just skip it to avoid
breaking the upgrade path for people who ran the mainline.
*/
if (oldVersion < 4) {
/*
try {
await tx
.objectStore("configs")
.delete(ML_SEARCH_CONFIG_NAME);
await tx
.objectStore("configs")
.delete(""ml-sync"");
await tx
.objectStore("configs")
.delete("ml-sync-job");
await tx
.objectStore("configs")
.add(
{ enabled: wasMLSearchEnabled },
ML_SEARCH_CONFIG_NAME,
);
db.deleteObjectStore("library");
db.deleteObjectStore("things");
} catch {
// TODO: ignore for now as we finalize the new version
// the shipped implementation should have a more
// deterministic migration.
}
*/
}
log.info(
`ML DB upgraded from version ${oldVersion} to version ${newVersion}`,
);
},
});
}
public get db(): Promise<IDBPDatabase<MLDb>> {
if (!this._db) {
this._db = this.openDB();
log.info("Opening Ml DB");
}
return this._db;
}
public async clearMLDB() {
const db = await this.db;
db.close();
await deleteDB(MLDATA_DB_NAME);
log.info("Cleared Ml DB");
this._db = undefined;
await this.db;
}
public async getAllFileIdsForUpdate(
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
return tx.store.getAllKeys();
}
public async getFileIds(
count: number,
limitMlVersion: number,
maxErrorCount: number,
) {
const db = await this.db;
const tx = db.transaction("files", "readonly");
const index = tx.store.index("mlVersion");
let cursor = await index.openKeyCursor(
IDBKeyRange.upperBound([limitMlVersion], true),
);
const fileIds: number[] = [];
while (cursor && fileIds.length < count) {
if (
cursor.key[0] < limitMlVersion &&
cursor.key[1] <= maxErrorCount
) {
fileIds.push(cursor.primaryKey);
}
cursor = await cursor.continue();
}
await tx.done;
return fileIds;
}
public async getFile(fileId: number): Promise<MinimalPersistedFileData> {
const db = await this.db;
return db.get("files", fileId);
}
public async putFile(mlFile: MlFileData) {
const db = await this.db;
return db.put("files", mlFile);
}
public async upsertFileInTx(
fileId: number,
upsert: (mlFile: MinimalPersistedFileData) => MinimalPersistedFileData,
) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
const existing = await tx.store.get(fileId);
const updated = upsert(existing);
await tx.store.put(updated);
await tx.done;
return updated;
}
public async putAllFiles(
mlFiles: MinimalPersistedFileData[],
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
}
public async removeAllFiles(
fileIds: Array<number>,
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
}
public async getPerson(id: number) {
const db = await this.db;
return db.get("people", id);
}
public async getAllPeople() {
const db = await this.db;
return db.getAll("people");
}
public async incrementIndexVersion(index: StoreNames<MLDb>) {
if (index === "versions") {
throw new Error("versions store can not be versioned");
}
const db = await this.db;
const tx = db.transaction(["versions", index], "readwrite");
let version = await tx.objectStore("versions").get(index);
version = (version || 0) + 1;
tx.objectStore("versions").put(version, index);
await tx.done;
return version;
}
public async getConfig<T extends Config>(name: string, def: T) {
const db = await this.db;
const tx = db.transaction("configs", "readwrite");
let config = (await tx.store.get(name)) as T;
if (!config) {
config = def;
await tx.store.put(def, name);
}
await tx.done;
return config;
}
public async putConfig(name: string, data: Config) {
const db = await this.db;
return db.put("configs", data, name);
}
public async getIndexStatus(latestMlVersion: number): Promise<IndexStatus> {
const db = await this.db;
const tx = db.transaction(["files", "versions"], "readonly");
const mlVersionIdx = tx.objectStore("files").index("mlVersion");
let outOfSyncCursor = await mlVersionIdx.openKeyCursor(
IDBKeyRange.upperBound([latestMlVersion], true),
);
let outOfSyncFilesExists = false;
while (outOfSyncCursor && !outOfSyncFilesExists) {
if (
outOfSyncCursor.key[0] < latestMlVersion &&
outOfSyncCursor.key[1] <= MAX_ML_SYNC_ERROR_COUNT
) {
outOfSyncFilesExists = true;
}
outOfSyncCursor = await outOfSyncCursor.continue();
}
const nSyncedFiles = await mlVersionIdx.count(
IDBKeyRange.lowerBound([latestMlVersion]),
);
const nTotalFiles = await mlVersionIdx.count();
const filesIndexVersion = await tx.objectStore("versions").get("files");
const peopleIndexVersion = await tx
.objectStore("versions")
.get("people");
const filesIndexVersionExists =
filesIndexVersion !== null && filesIndexVersion !== undefined;
const peopleIndexVersionExists =
peopleIndexVersion !== null && peopleIndexVersion !== undefined;
await tx.done;
return {
outOfSyncFilesExists,
nSyncedFiles,
nTotalFiles,
localFilesSynced: filesIndexVersionExists,
peopleIndexSynced:
peopleIndexVersionExists &&
peopleIndexVersion === filesIndexVersion,
};
}
}
export default new MLIDbStorage();

View File

@@ -82,6 +82,8 @@ interface FileStatus {
let _faceDB: ReturnType<typeof openFaceDB> | undefined;
const openFaceDB = async () => {
deleteLegacyDB();
const db = await openDB<FaceDBSchema>("face", 1, {
upgrade(db, oldVersion, newVersion) {
log.info(`Upgrading face DB ${oldVersion} => ${newVersion}`);
@@ -112,6 +114,13 @@ const openFaceDB = async () => {
return db;
};
const deleteLegacyDB = () => {
// Delete the legacy face DB.
// This code was added June 2024 (v1.7.1-rc) and can be removed once clients
// have migrated over.
void deleteDB("mldata");
};
/**
* @returns a lazily created, cached connection to the face DB.
*/
@@ -138,6 +147,7 @@ export const closeFaceDBConnectionsIfNeeded = async () => {
* Meant to be called during logout.
*/
export const clearFaceData = async () => {
deleteLegacyDB();
await closeFaceDBConnectionsIfNeeded();
return deleteDB("face", {
blocked() {
@@ -173,6 +183,14 @@ export const saveFaceIndex = async (faceIndex: FaceIndex) => {
]);
};
/**
* Return the {@link FaceIndex}, if any, for {@link fileID}.
*/
export const faceIndex = async (fileID: number) => {
const db = await faceDB();
return db.get("face-index", fileID);
};
/**
* Record the existence of a file so that entities in the face indexing universe
* know about it (e.g. can index it if it is new and it needs indexing).
@@ -197,6 +215,66 @@ export const addFileEntry = async (fileID: number) => {
return tx.done;
};
/**
* Sync entries in the face DB to align with the given list of local indexable
* file IDs.
*
* @param localFileIDs The IDs of all the files that the client is aware of,
* filtered to only keep the files that the user owns and the formats that can
* be indexed by our current face indexing pipeline.
*
* This function syncs the state of file entries in face DB to match the state
* of the files known to the local client.
*
* - Files (identified by their ID) that are present locally but are not yet in
* face DB get a fresh entry in face DB (and are marked as indexable).
*
* - Files that are not present locally but still exist in face DB are removed
* from face DB (including its face index, if any).
*/
export const syncWithLocalIndexableFileIDs = async (localFileIDs: number[]) => {
const db = await faceDB();
const tx = db.transaction(["face-index", "file-status"], "readwrite");
const fdbFileIDs = await tx.objectStore("file-status").getAllKeys();
const local = new Set(localFileIDs);
const fdb = new Set(fdbFileIDs);
const newFileIDs = localFileIDs.filter((id) => !fdb.has(id));
const removedFileIDs = fdbFileIDs.filter((id) => !local.has(id));
return Promise.all(
[
newFileIDs.map((id) =>
tx.objectStore("file-status").put({
fileID: id,
isIndexable: 1,
failureCount: 0,
}),
),
removedFileIDs.map((id) =>
tx.objectStore("file-status").delete(id),
),
removedFileIDs.map((id) => tx.objectStore("face-index").delete(id)),
tx.done,
].flat(),
);
};
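To make the reconciliation concrete, a standalone sketch of the set difference computed above, using hypothetical IDs:
// Suppose the client knows about files 1, 2 and 3, while the face DB has
// entries for 2, 3 and 4.
const localFileIDs = [1, 2, 3];
const fdbFileIDs = [2, 3, 4];
const local = new Set(localFileIDs);
const fdb = new Set(fdbFileIDs);
// 1 is new, so it gets a fresh, indexable "file-status" entry.
const newFileIDs = localFileIDs.filter((id) => !fdb.has(id)); // [1]
// 4 no longer exists locally, so its "file-status" entry and any
// "face-index" entry get deleted.
const removedFileIDs = fdbFileIDs.filter((id) => !local.has(id)); // [4]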
/**
* Return the count of files that can be, and that have been, indexed.
*/
export const indexedAndIndexableCounts = async () => {
const db = await faceDB();
const tx = db.transaction(["face-index", "file-status"], "readwrite");
const indexedCount = await tx.objectStore("face-index").count();
const indexableCount = await tx
.objectStore("file-status")
.index("isIndexable")
.count(IDBKeyRange.only(1));
return { indexedCount, indexableCount };
};
/**
* Return a list of fileIDs that need to be indexed.
*
@@ -204,11 +282,13 @@ export const addFileEntry = async (fileID: number) => {
* (can use {@link addFileEntry} to inform it about new files). From this
* universe, we filter out the fileIDs of files that have already been
* indexed, or for which we attempted indexing but failed.
*
* @param count Limit the result to up to {@link count} items.
*/
export const unindexedFileIDs = async () => {
export const unindexedFileIDs = async (count?: number) => {
const db = await faceDB();
const tx = db.transaction("file-status", "readonly");
return tx.store.index("isIndexable").getAllKeys(IDBKeyRange.only(1));
return tx.store.index("isIndexable").getAllKeys(IDBKeyRange.only(1), count);
};
/**

View File

@@ -1,8 +1,9 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import { Matrix } from "ml-matrix";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import DownloadManager from "services/download";
import { getSimilarityTransformation } from "similarity-transformation";
import {
Matrix as TransformationMatrix,
@@ -12,21 +13,15 @@ import {
translate,
} from "transformation-matrix";
import type { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
import { saveFaceCrop } from "./crop";
import { fetchImageBitmap, getLocalFileImageBitmap } from "./file";
import {
clamp,
grayscaleIntMatrixFromNormalized2List,
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
import type { Box, Dimensions } from "./types";
import type {
Face,
FaceAlignment,
FaceDetection,
MlFileData,
} from "./types-old";
import type { Box, Dimensions, Face, Point } from "./types";
/**
* Index faces in the given file.
@@ -43,95 +38,120 @@ import type {
* they can be saved locally for offline use, and encrypts and uploads them to
* the user's remote storage so that their other devices can download them
* instead of needing to reindex.
*
* @param enteFile The {@link EnteFile} to index.
*
* @param file The contents of {@link enteFile} as a web {@link File}, if
* available. These are used when provided; otherwise the file is
* downloaded and decrypted from remote.
*
* @param userAgent The UA of the current client (the client that is generating
* the embedding).
*/
export const indexFaces = async (enteFile: EnteFile, localFile?: File) => {
const startTime = Date.now();
export const indexFaces = async (
enteFile: EnteFile,
file: File | undefined,
userAgent: string,
) => {
const imageBitmap = await renderableImageBlob(enteFile, file).then(
createImageBitmap,
);
const { width, height } = imageBitmap;
const fileID = enteFile.id;
const imageBitmap = await fetchOrCreateImageBitmap(enteFile, localFile);
let mlFile: MlFileData;
try {
mlFile = await indexFaces_(enteFile, imageBitmap);
return {
fileID,
width,
height,
faceEmbedding: {
version: 1,
client: userAgent,
faces: await indexFacesInBitmap(fileID, imageBitmap),
},
};
} finally {
imageBitmap.close();
}
log.debug(() => {
const nf = mlFile.faces?.length ?? 0;
const ms = Date.now() - startTime;
return `Indexed ${nf} faces in file ${enteFile.id} (${ms} ms)`;
});
return mlFile;
};
/**
* Return a {@link ImageBitmap}, using {@link localFile} if present otherwise
* Return a "renderable" image blob, using {@link file} if present otherwise
* downloading the source image corresponding to {@link enteFile} from remote.
*/
const fetchOrCreateImageBitmap = async (
enteFile: EnteFile,
localFile: File,
) => {
const fileType = enteFile.metadata.fileType;
if (localFile) {
// TODO-ML(MR): Could also be image part of live photo?
if (fileType !== FILE_TYPE.IMAGE)
throw new Error("Local file of only image type is supported");
const renderableImageBlob = async (enteFile: EnteFile, file: File) =>
file
? getRenderableImage(enteFile.metadata.title, file)
: fetchRenderableBlob(enteFile);
return await getLocalFileImageBitmap(enteFile, localFile);
} else if ([FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(fileType)) {
return await fetchImageBitmap(enteFile);
const fetchRenderableBlob = async (enteFile: EnteFile) => {
const fileStream = await DownloadManager.getFile(enteFile);
const fileBlob = await new Response(fileStream).blob();
const fileType = enteFile.metadata.fileType;
if (fileType == FILE_TYPE.IMAGE) {
return getRenderableImage(enteFile.metadata.title, fileBlob);
} else if (fileType == FILE_TYPE.LIVE_PHOTO) {
const { imageFileName, imageData } = await decodeLivePhoto(
enteFile.metadata.title,
fileBlob,
);
return getRenderableImage(imageFileName, new Blob([imageData]));
} else {
// A layer above us should've already filtered these out.
throw new Error(`Cannot index unsupported file type ${fileType}`);
}
};
const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
const fileID = enteFile.id;
const indexFacesInBitmap = async (
fileID: number,
imageBitmap: ImageBitmap,
): Promise<Face[]> => {
const { width, height } = imageBitmap;
const imageDimensions = { width, height };
const mlFile: MlFileData = {
fileId: fileID,
mlVersion: defaultMLVersion,
imageDimensions,
errorCount: 0,
};
const faceDetections = await detectFaces(imageBitmap);
const detectedFaces = faceDetections.map((detection) => ({
id: makeFaceID(fileID, detection, imageDimensions),
fileId: fileID,
detection,
}));
mlFile.faces = detectedFaces;
const yoloFaceDetections = await detectFaces(imageBitmap);
const partialResult = yoloFaceDetections.map(
({ box, landmarks, score }) => {
const faceID = makeFaceID(fileID, box, imageDimensions);
const detection = { box, landmarks };
return { faceID, detection, score };
},
);
if (detectedFaces.length > 0) {
const alignments: FaceAlignment[] = [];
const alignments: FaceAlignment[] = [];
for (const face of mlFile.faces) {
const alignment = faceAlignment(face.detection);
face.alignment = alignment;
alignments.push(alignment);
for (const { faceID, detection } of partialResult) {
const alignment = computeFaceAlignment(detection);
alignments.push(alignment);
await saveFaceCrop(imageBitmap, face);
// This step is not part of the indexing pipeline; we just do it here
// since we have already computed the face alignment. Ignore errors that
// happen during this since it does not impact the generated face index.
try {
await saveFaceCrop(imageBitmap, faceID, alignment);
} catch (e) {
log.error(`Failed to save face crop for faceID ${faceID}`, e);
}
const alignedFacesData = convertToMobileFaceNetInput(
imageBitmap,
alignments,
);
const blurValues = detectBlur(alignedFacesData, mlFile.faces);
mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
const embeddings = await computeEmbeddings(alignedFacesData);
mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
mlFile.faces.forEach((face) => {
face.detection = relativeDetection(face.detection, imageDimensions);
});
}
return mlFile;
const alignedFacesData = convertToMobileFaceNetInput(
imageBitmap,
alignments,
);
const embeddings = await computeEmbeddings(alignedFacesData);
const blurs = detectBlur(
alignedFacesData,
partialResult.map((f) => f.detection),
);
return partialResult.map(({ faceID, detection, score }, i) => ({
faceID,
detection: normalizeToImageDimensions(detection, imageDimensions),
score,
blur: blurs[i],
embedding: Array.from(embeddings[i]),
}));
};
/**
@@ -141,14 +161,14 @@ const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
*/
const detectFaces = async (
imageBitmap: ImageBitmap,
): Promise<FaceDetection[]> => {
): Promise<YOLOFaceDetection[]> => {
const rect = ({ width, height }) => ({ x: 0, y: 0, width, height });
const { yoloInput, yoloSize } =
convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
const yoloOutput = await workerBridge.detectFaces(yoloInput);
const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput);
const faceDetections = transformFaceDetections(
const faceDetections = transformYOLOFaceDetections(
faces,
rect(yoloSize),
rect(imageBitmap),
@@ -209,6 +229,12 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
return { yoloInput, yoloSize };
};
export interface YOLOFaceDetection {
box: Box;
landmarks: Point[];
score: number;
}
/**
* Extract detected faces from the YOLOv5Face's output.
*
@@ -227,8 +253,8 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
*/
const filterExtractDetectionsFromYOLOOutput = (
rows: Float32Array,
): FaceDetection[] => {
const faces: FaceDetection[] = [];
): YOLOFaceDetection[] => {
const faces: YOLOFaceDetection[] = [];
// Iterate over each row.
for (let i = 0; i < rows.length; i += 16) {
const score = rows[i + 4];
@@ -253,7 +279,6 @@ const filterExtractDetectionsFromYOLOOutput = (
const rightMouthY = rows[i + 14];
const box = { x, y, width, height };
const probability = score as number;
const landmarks = [
{ x: leftEyeX, y: leftEyeY },
{ x: rightEyeX, y: rightEyeY },
@@ -261,26 +286,26 @@ const filterExtractDetectionsFromYOLOOutput = (
{ x: leftMouthX, y: leftMouthY },
{ x: rightMouthX, y: rightMouthY },
];
faces.push({ box, landmarks, probability });
faces.push({ box, landmarks, score });
}
return faces;
};
/**
* Transform the given {@link faceDetections} from their coordinate system in
* Transform the given {@link yoloFaceDetections} from their coordinate system in
* which they were detected ({@link inBox}) back to the coordinate system of the
* original image ({@link toBox}).
*/
const transformFaceDetections = (
faceDetections: FaceDetection[],
const transformYOLOFaceDetections = (
yoloFaceDetections: YOLOFaceDetection[],
inBox: Box,
toBox: Box,
): FaceDetection[] => {
): YOLOFaceDetection[] => {
const transform = boxTransformationMatrix(inBox, toBox);
return faceDetections.map((f) => ({
return yoloFaceDetections.map((f) => ({
box: transformBox(f.box, transform),
landmarks: f.landmarks.map((p) => applyToPoint(transform, p)),
probability: f.probability,
score: f.score,
}));
};
@@ -312,8 +337,8 @@ const transformBox = (box: Box, transform: TransformationMatrix): Box => {
* Remove overlapping faces from an array of face detections through non-maximum
* suppression algorithm.
*
* This function sorts the detections by their probability in descending order,
* then iterates over them.
* This function sorts the detections by their score in descending order, then
* iterates over them.
*
* For each detection, it calculates the Intersection over Union (IoU) with all
* other detections.
@@ -322,8 +347,8 @@ const transformBox = (box: Box, transform: TransformationMatrix): Box => {
* (`iouThreshold`), the other detection is considered overlapping and is
* removed.
*
* @param detections - An array of face detections to remove overlapping faces
* from.
* @param detections - An array of YOLO face detections to remove overlapping
* faces from.
*
* @param iouThreshold - The minimum IoU between two detections for them to be
* considered overlapping.
@@ -331,11 +356,11 @@ const transformBox = (box: Box, transform: TransformationMatrix): Box => {
* @returns An array of face detections with overlapping faces removed
*/
const naiveNonMaxSuppression = (
detections: FaceDetection[],
detections: YOLOFaceDetection[],
iouThreshold: number,
): FaceDetection[] => {
): YOLOFaceDetection[] => {
// Sort the detections by score, the highest first.
detections.sort((a, b) => b.probability - a.probability);
detections.sort((a, b) => b.score - a.score);
// Loop through the detections and calculate the IOU.
for (let i = 0; i < detections.length - 1; i++) {
@@ -379,11 +404,7 @@ const intersectionOverUnion = (a: FaceDetection, b: FaceDetection): number => {
return intersectionArea / unionArea;
};
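For intuition about the IoU threshold used by the suppression above, a minimal standalone IoU for axis-aligned boxes (illustrative only; the intersectionOverUnion in this module is the one actually used):
type IoUBox = { x: number; y: number; width: number; height: number };
const iou = (a: IoUBox, b: IoUBox): number => {
    // Overlap along each axis, clamped to zero when the boxes are disjoint.
    const xOverlap = Math.max(
        0,
        Math.min(a.x + a.width, b.x + b.width) - Math.max(a.x, b.x),
    );
    const yOverlap = Math.max(
        0,
        Math.min(a.y + a.height, b.y + b.height) - Math.max(a.y, b.y),
    );
    const intersection = xOverlap * yOverlap;
    const union = a.width * a.height + b.width * b.height - intersection;
    return union == 0 ? 0 : intersection / union;
};
// Two 100×100 boxes offset by 50 px share a 50×100 strip, so
// IoU = 5000 / (10000 + 10000 - 5000) ≈ 0.33.
iou({ x: 0, y: 0, width: 100, height: 100 }, { x: 50, y: 0, width: 100, height: 100 });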
const makeFaceID = (
fileID: number,
{ box }: FaceDetection,
image: Dimensions,
) => {
const makeFaceID = (fileID: number, box: Box, image: Dimensions) => {
const part = (v: number) => clamp(v, 0.0, 0.999999).toFixed(5).substring(2);
const xMin = part(box.x / image.width);
const yMin = part(box.y / image.height);
@@ -392,13 +413,30 @@ const makeFaceID = (
return [`${fileID}`, xMin, yMin, xMax, yMax].join("_");
};
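A worked example of the resulting ID format. The xMax and yMax computations are elided by the hunk above, but assuming they derive from box.x + box.width and box.y + box.height in the same way:
// For fileID 42 and box { x: 100, y: 50, width: 200, height: 300 } in a
// 1000×800 image:
//   xMin = part(100 / 1000) = part(0.1)    → "10000"
//   yMin = part(50 / 800)   = part(0.0625) → "06250"
//   xMax = part(300 / 1000) = part(0.3)    → "30000"
//   yMax = part(350 / 800)  = part(0.4375) → "43750"
console.assert(
    makeFaceID(
        42,
        { x: 100, y: 50, width: 200, height: 300 },
        { width: 1000, height: 800 },
    ) == "42_10000_06250_30000_43750",
);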
export interface FaceAlignment {
/**
* An affine transformation matrix (rotation, translation, scaling) to align
* the face extracted from the image.
*/
affineMatrix: number[][];
/**
* The bounding box of the transformed box.
*
* The affine transformation shifts the original detection box to a new,
* transformed box (possibly rotated). This property is the bounding box
* of that transformed box. It is in the coordinate system of the original,
* full, image on which the detection occurred.
*/
boundingBox: Box;
}
/**
* Compute and return an {@link FaceAlignment} for the given face detection.
*
* @param faceDetection A geometry indicating a face detected in an image.
*/
const faceAlignment = (faceDetection: FaceDetection): FaceAlignment =>
faceAlignmentUsingSimilarityTransform(
const computeFaceAlignment = (faceDetection: FaceDetection): FaceAlignment =>
computeFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(idealMobileFaceNetLandmarks, mobileFaceNetFaceSize),
);
@@ -421,7 +459,7 @@ const normalizeLandmarks = (
): [number, number][] =>
landmarks.map(([x, y]) => [x / faceSize, y / faceSize]);
const faceAlignmentUsingSimilarityTransform = (
const computeFaceAlignmentUsingSimilarityTransform = (
faceDetection: FaceDetection,
alignedLandmarks: [number, number][],
): FaceAlignment => {
@@ -483,28 +521,35 @@ const convertToMobileFaceNetInput = (
return faceData;
};
interface FaceDetection {
box: Box;
landmarks: Point[];
}
/**
* Laplacian blur detection.
*
* Return an array of detected blur values, one for each face in {@link faces}.
* The face data is taken from the slice of {@link alignedFacesData}
* corresponding to each face of {@link faces}.
* Return an array of detected blur values, one for each face detection in
* {@link faceDetections}. The face data is taken from the slice of
* {@link alignedFacesData} corresponding to each face in {@link faceDetections}.
*/
const detectBlur = (alignedFacesData: Float32Array, faces: Face[]): number[] =>
faces.map((face, i) => {
const detectBlur = (
alignedFacesData: Float32Array,
faceDetections: FaceDetection[],
): number[] =>
faceDetections.map((d, i) => {
const faceImage = grayscaleIntMatrixFromNormalized2List(
alignedFacesData,
i,
mobileFaceNetFaceSize,
mobileFaceNetFaceSize,
);
return matrixVariance(applyLaplacian(faceImage, faceDirection(face)));
return matrixVariance(applyLaplacian(faceImage, faceDirection(d)));
});
type FaceDirection = "left" | "right" | "straight";
const faceDirection = (face: Face): FaceDirection => {
const landmarks = face.detection.landmarks;
const faceDirection = ({ landmarks }: FaceDetection): FaceDirection => {
const leftEye = landmarks[0];
const rightEye = landmarks[1];
const nose = landmarks[2];
@@ -694,7 +739,7 @@ const computeEmbeddings = async (
/**
* Convert the coordinates to between 0-1, normalized by the image's dimensions.
*/
const relativeDetection = (
const normalizeToImageDimensions = (
faceDetection: FaceDetection,
{ width, height }: Dimensions,
): FaceDetection => {
@@ -709,6 +754,5 @@ const relativeDetection = (
x: l.x / width,
y: l.y / height,
}));
const probability = faceDetection.probability;
return { box, landmarks, probability };
return { box, landmarks };
};

View File

@@ -1,37 +0,0 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
async function fetchRenderableBlob(file: EnteFile) {
const fileStream = await DownloadManager.getFile(file);
const fileBlob = await new Response(fileStream).blob();
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return createImageBitmap(fileBlob);
}

View File

@@ -0,0 +1,248 @@
import { FILE_TYPE } from "@/media/file-type";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { ensure } from "@/utils/ensure";
import { wait } from "@/utils/promise";
import { type Remote } from "comlink";
import { getLocalFiles } from "services/fileService";
import machineLearningService from "services/machineLearning/machineLearningService";
import mlWorkManager from "services/machineLearning/mlWorkManager";
import type { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import {
faceIndex,
indexedAndIndexableCounts,
syncWithLocalIndexableFileIDs,
unindexedFileIDs,
} from "./db";
import { FaceIndexerWorker } from "./indexer.worker";
/**
* Face indexing orchestrator.
*
* This module exposes a singleton instance of this class which drives the face
* indexing process on the user's library.
*
* The indexer operates in two modes - live indexing and backfill.
*
* When live indexing, any files that are being uploaded from the current client
* are provided to the indexer, which puts them in a queue and indexes them one
* by one. This is more efficient since we already have the file's content at
* hand and do not have to download and decrypt it.
*
* When backfilling, the indexer figures out if any of the user's files
* (irrespective of where they were uploaded from) still need to be indexed, and
* if so, downloads, decrypts and indexes them.
*
* Live indexing has higher priority, backfilling runs otherwise. If nothing
* remains to be indexed, the indexer goes to sleep for a while.
*/
class FaceIndexer {
/** Live indexing queue. */
private liveItems: { enteFile: EnteFile; file: File | undefined }[];
/** Timeout for when the next time we will wake up. */
private wakeTimeout: ReturnType<typeof setTimeout> | undefined;
/**
* Add a file to the live indexing queue.
*
* @param enteFile An {@link EnteFile} that should be indexed.
*
* @param file The contents of {@link enteFile} as a web {@link File}
* object, if available.
*/
enqueueFile(enteFile: EnteFile, file: File | undefined) {
// If face indexing is not enabled, don't enqueue anything. Later on if
// the user turns on face indexing, these files will get indexed as part
// of the backfilling anyway; the live indexing is just an optimization.
if (!mlWorkManager.isMlSearchEnabled) return;
this.liveItems.push({ enteFile, file });
this.wakeUpIfNeeded();
}
private wakeUpIfNeeded() {
// Already awake.
if (!this.wakeTimeout) return;
// Cancel the alarm, wake up now.
clearTimeout(this.wakeTimeout);
this.wakeTimeout = undefined;
// Get to work.
this.tick();
}
/**
* A promise for the lazily created singleton {@link FaceIndexerWorker} remote
* exposed by this module.
*/
_faceIndexer: Promise<Remote<FaceIndexerWorker>>;
/**
* Main thread interface to the face indexer.
*
* This function provides a promise that resolves to a lazily created singleton
* remote with a {@link FaceIndexerWorker} at the other end.
*/
faceIndexer = (): Promise<Remote<FaceIndexerWorker>> =>
(this._faceIndexer ??= createFaceIndexerComlinkWorker().remote);
private async tick() {
console.log("tick");
const item = this.liveItems.pop();
if (!item) {
// TODO-ML: backfill instead if needed here.
this.wakeTimeout = setTimeout(() => {
this.wakeTimeout = undefined;
this.wakeUpIfNeeded();
}, 30 * 1000);
return;
}
/*
const fileID = item.enteFile.id;
try {
const faceIndex = await indexFaces(item.enteFile, item.file, userAgent);
log.info(`faces in file ${fileID}`, faceIndex);
} catch (e) {
log.error(`Failed to index faces in file ${fileID}`, e);
markIndexingFailed(item.enteFile.id);
}
*/
// Let the runloop drain.
await wait(0);
// Run again.
// TODO
// this.tick();
}
/**
* Add a newly uploaded file to the face indexing queue.
*
* @param enteFile The {@link EnteFile} that was uploaded.
* @param file
*/
/*
indexFacesInFile = (enteFile: EnteFile, file: File) => {
if (!mlWorkManager.isMlSearchEnabled) return;
faceIndexer().then((indexer) => {
indexer.enqueueFile(file, enteFile);
});
};
*/
}
/** The singleton instance of {@link FaceIndexer}. */
export default new FaceIndexer();
const createFaceIndexerComlinkWorker = () =>
new ComlinkWorker<typeof FaceIndexerWorker>(
"face-indexer",
new Worker(new URL("indexer.worker.ts", import.meta.url)),
);
export interface FaceIndexingStatus {
/**
* Which phase we are in within the indexing pipeline when viewed across the
* user's entire library:
*
* - "scheduled": There are files we know of that have not been indexed.
*
* - "indexing": The face indexer is currently running.
*
* - "clustering": All files we know of have been indexed, and we are now
* clustering the faces that were found.
*
* - "done": Face indexing and clustering is complete for the user's
* library.
*/
phase: "scheduled" | "indexing" | "clustering" | "done";
/** The number of files that have already been indexed. */
nSyncedFiles: number;
/** The total number of files that are eligible for indexing. */
nTotalFiles: number;
}
export const faceIndexingStatus = async (): Promise<FaceIndexingStatus> => {
const isSyncing = machineLearningService.isSyncing;
const { indexedCount, indexableCount } = await indexedAndIndexableCounts();
let phase: FaceIndexingStatus["phase"];
if (indexedCount < indexableCount) {
if (!isSyncing) {
phase = "scheduled";
} else {
phase = "indexing";
}
} else {
phase = "done";
}
return {
phase,
nSyncedFiles: indexedCount,
nTotalFiles: indexableCount,
};
};
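A sketch of how a progress UI might consume this status (the polling and formatting are illustrative):
const logIndexingProgress = async () => {
    const { phase, nSyncedFiles, nTotalFiles } = await faceIndexingStatus();
    const pct =
        nTotalFiles > 0 ? Math.round((100 * nSyncedFiles) / nTotalFiles) : 100;
    console.log(
        `Face indexing: ${phase} (${nSyncedFiles}/${nTotalFiles}, ${pct}%)`,
    );
};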
/**
* Return the IDs of all the faces in the given {@link enteFile} that are not
* associated with a person cluster.
*/
export const unidentifiedFaceIDs = async (
enteFile: EnteFile,
): Promise<string[]> => {
const index = await faceIndex(enteFile.id);
return index?.faceEmbedding.faces.map((f) => f.faceID) ?? [];
};
/**
* Return true if the user has enabled face indexing in the app's settings.
*
* This setting is persisted locally (in local storage) and is not synced with
* remote. There is a separate setting, "faceSearchEnabled", that is synced with
* remote, but that tracks whether or not the user has enabled face search once
* on any client. This {@link isFaceIndexingEnabled} property, on the other
* hand, denotes whether or not indexing is enabled on the current client.
*/
export const isFaceIndexingEnabled = async () => {
if (isInternalUserForML()) {
return localStorage.getItem("faceIndexingEnabled") == "1";
}
// Force disabled for everyone else while we finalize it to avoid redundant
// reindexing for users.
return false;
};
/**
* Update the (locally stored) value of {@link isFaceIndexingEnabled}.
*/
export const setIsFaceIndexingEnabled = async (enabled: boolean) => {
if (enabled) localStorage.setItem("faceIndexingEnabled", "1");
else localStorage.removeItem("faceIndexingEnabled");
};
/**
* Sync face DB with the local indexable files that we know about. Then return
* the next {@link count} files that still need to be indexed.
*
* For more specifics of what a "sync" entails, see
* {@link syncWithLocalIndexableFileIDs}.
*
* @param userID Limit indexing to files owned by {@link userID}.
*
* @param count Limit the resulting list of files to {@link count}.
*/
export const getFilesToIndex = async (userID: number, count: number) => {
const localFiles = await getLocalFiles();
const indexableTypes = [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO];
const indexableFiles = localFiles.filter(
(f) =>
f.ownerID == userID && indexableTypes.includes(f.metadata.fileType),
);
const filesByID = new Map(indexableFiles.map((f) => [f.id, f]));
await syncWithLocalIndexableFileIDs([...filesByID.keys()]);
const fileIDsToIndex = await unindexedFileIDs(count);
return fileIDsToIndex.map((id) => ensure(filesByID.get(id)));
};
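For context, a sketch of the backfill loop that a caller might build on top of this, mirroring how machineLearningService uses it further below (the batch size and worker setup are illustrative):
const backfill = async (userID: number, userAgent: string) => {
    const indexer = await createFaceIndexerComlinkWorker().remote;
    const batchSize = 200;
    while (true) {
        const batch = await getFilesToIndex(userID, batchSize);
        if (batch.length == 0) break;
        for (const enteFile of batch) {
            // No local File is at hand when backfilling, so the worker
            // downloads and decrypts the file's contents itself.
            await indexer.index(enteFile, undefined, userAgent);
        }
    }
};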

View File

@@ -0,0 +1,75 @@
import log from "@/next/log";
import type { EnteFile } from "types/file";
import { fileLogID } from "utils/file";
import {
closeFaceDBConnectionsIfNeeded,
markIndexingFailed,
saveFaceIndex,
} from "./db";
import { indexFaces } from "./f-index";
import { putFaceIndex } from "./remote";
import type { FaceIndex } from "./types";
/**
* Index faces in a file, persist the results locally, and put them on
* remote.
*
* This class is instantiated in a Web Worker so as to not get in the way of the
* main thread. It could've been a bunch of free standing functions too, it is
* just a class for convenience of compatibility with how the rest of our
* comlink workers are structured.
*/
export class FaceIndexerWorker {
/**
* Index faces in a file, persist the results locally, and put them
* on remote.
*
* @param enteFile The {@link EnteFile} to index.
*
* @param file If the file is one which is being uploaded from the current
* client, then we will also have access to the file's content. In such
* cases, pass a web {@link File} object to use its data directly for
* face indexing. If this is not provided, then the file's contents will be
* downloaded and decrypted from remote.
*/
async index(enteFile: EnteFile, file: File | undefined, userAgent: string) {
const f = fileLogID(enteFile);
const startTime = Date.now();
let faceIndex: FaceIndex;
try {
faceIndex = await indexFaces(enteFile, file, userAgent);
} catch (e) {
// Mark indexing as having failed only if the indexing itself
// failed, not if there were subsequent failures (like when trying
// to put the result to remote or save it to the local face DB).
log.error(`Failed to index faces in ${f}`, e);
markIndexingFailed(enteFile.id);
throw e;
}
try {
await putFaceIndex(enteFile, faceIndex);
await saveFaceIndex(faceIndex);
} catch (e) {
log.error(`Failed to put/save face index for ${f}`, e);
throw e;
}
log.debug(() => {
const nf = faceIndex.faceEmbedding.faces.length;
const ms = Date.now() - startTime;
return `Indexed ${nf} faces in ${f} (${ms} ms)`;
});
return faceIndex;
}
/**
* Calls {@link closeFaceDBConnectionsIfNeeded} to close any open
* connections to the face DB from the web worker's context.
*/
closeFaceDB() {
closeFaceDBConnectionsIfNeeded();
}
}

View File

@@ -84,6 +84,10 @@ export const syncPeopleIndex = async () => {
: best,
);
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
if (personFace && !personFace.crop?.cacheKey) {
const file = await getLocalFile(personFace.fileId);

View File

@@ -2,24 +2,20 @@ import log from "@/next/log";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { putEmbedding } from "services/embeddingService";
import type { EnteFile } from "types/file";
import type { Point } from "./types";
import type { Face, FaceDetection, MlFileData } from "./types-old";
import type { FaceIndex } from "./types";
export const putFaceEmbedding = async (
export const putFaceIndex = async (
enteFile: EnteFile,
mlFileData: MlFileData,
userAgent: string,
faceIndex: FaceIndex,
) => {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData, userAgent);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
t: "Uploading faceEmbedding",
d: JSON.stringify(faceIndex),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
await comlinkCryptoWorker.encryptMetadata(faceIndex, enteFile.key);
await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
@@ -27,123 +23,3 @@ export const putFaceEmbedding = async (
model: "file-ml-clip-face",
});
};
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client: string;
public constructor(faces: ServerFace[], client: string, version: number) {
this.faces = faces;
this.client = client;
this.version = version;
}
}
class ServerFace {
public faceID: string;
public embedding: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embedding: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embedding = embedding;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Point[];
public constructor(box: ServerFaceBox, landmarks: Point[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public x: number;
public y: number;
public width: number;
public height: number;
public constructor(x: number, y: number, width: number, height: number) {
this.x = x;
this.y = y;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
userAgent: string,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, landmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, userAgent, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}

View File

@@ -1,46 +0,0 @@
import type { Box, Dimensions, Point } from "./types";
export interface FaceDetection {
// box and landmarks are relative to the image dimensions stored in mlFileData
box: Box;
landmarks?: Point[];
probability?: number;
}
export interface FaceAlignment {
/**
* An affine transformation matrix (rotation, translation, scaling) to align
* the face extracted from the image.
*/
affineMatrix: number[][];
/**
* The bounding box of the transformed box.
*
* The affine transformation shifts the original detection box to a new,
* transformed box (possibly rotated). This property is the bounding box
* of that transformed box. It is in the coordinate system of the original,
* full, image on which the detection occurred.
*/
boundingBox: Box;
}
export interface Face {
fileId: number;
detection: FaceDetection;
id: string;
alignment?: FaceAlignment;
blurValue?: number;
embedding?: Float32Array;
personId?: number;
}
export interface MlFileData {
fileId: number;
faces?: Face[];
imageDimensions?: Dimensions;
mlVersion: number;
errorCount: number;
}

View File

@@ -1,46 +1,12 @@
import log from "@/next/log";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import PQueue from "p-queue";
import mlIDbStorage, {
ML_SEARCH_CONFIG_NAME,
type MinimalPersistedFileData,
} from "services/face/db-old";
import { putFaceEmbedding } from "services/face/remote";
import { getLocalFiles } from "services/fileService";
import { getFilesToIndex } from "services/face/indexer";
import { FaceIndexerWorker } from "services/face/indexer.worker";
import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import { indexFaces } from "../face/f-index";
export const defaultMLVersion = 1;
const batchSize = 200;
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export interface MLSearchConfig {
enabled: boolean;
}
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
export async function getMLSearchConfig() {
if (isInternalUserForML()) {
return mlIDbStorage.getConfig(
ML_SEARCH_CONFIG_NAME,
DEFAULT_ML_SEARCH_CONFIG,
);
}
// Force disabled for everyone else while we finalize it to avoid redundant
// reindexing for users.
return DEFAULT_ML_SEARCH_CONFIG;
}
export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
}
class MLSyncContext {
public token: string;
public userID: number;
@@ -79,6 +45,8 @@ class MachineLearningService {
private localSyncContext: Promise<MLSyncContext>;
private syncContext: Promise<MLSyncContext>;
public isSyncing = false;
public async sync(
token: string,
userID: number,
@@ -90,9 +58,7 @@ class MachineLearningService {
const syncContext = await this.getSyncContext(token, userID, userAgent);
await this.syncLocalFiles(syncContext);
await this.getOutOfSyncFiles(syncContext);
syncContext.outOfSyncFiles = await getFilesToIndex(userID, batchSize);
if (syncContext.outOfSyncFiles.length > 0) {
await this.syncFiles(syncContext);
@@ -103,96 +69,8 @@ class MachineLearningService {
return !error && nOutOfSyncFiles > 0;
}
private newMlData(fileId: number) {
return {
fileId,
mlVersion: 0,
errorCount: 0,
} as MinimalPersistedFileData;
}
private async getLocalFilesMap(syncContext: MLSyncContext) {
if (!syncContext.localFilesMap) {
const localFiles = await getLocalFiles();
const personalFiles = localFiles.filter(
(f) => f.ownerID === syncContext.userID,
);
syncContext.localFilesMap = new Map<number, EnteFile>();
personalFiles.forEach((f) =>
syncContext.localFilesMap.set(f.id, f),
);
}
return syncContext.localFilesMap;
}
private async syncLocalFiles(syncContext: MLSyncContext) {
const startTime = Date.now();
const localFilesMap = await this.getLocalFilesMap(syncContext);
const db = await mlIDbStorage.db;
const tx = db.transaction("files", "readwrite");
const mlFileIdsArr = await mlIDbStorage.getAllFileIdsForUpdate(tx);
const mlFileIds = new Set<number>();
mlFileIdsArr.forEach((mlFileId) => mlFileIds.add(mlFileId));
const newFileIds: Array<number> = [];
for (const localFileId of localFilesMap.keys()) {
if (!mlFileIds.has(localFileId)) {
newFileIds.push(localFileId);
}
}
let updated = false;
if (newFileIds.length > 0) {
log.info("newFiles: ", newFileIds.length);
const newFiles = newFileIds.map((fileId) => this.newMlData(fileId));
await mlIDbStorage.putAllFiles(newFiles, tx);
updated = true;
}
const removedFileIds: Array<number> = [];
for (const mlFileId of mlFileIds) {
if (!localFilesMap.has(mlFileId)) {
removedFileIds.push(mlFileId);
}
}
if (removedFileIds.length > 0) {
log.info("removedFiles: ", removedFileIds.length);
await mlIDbStorage.removeAllFiles(removedFileIds, tx);
updated = true;
}
await tx.done;
if (updated) {
// TODO: This should be done in the same transaction.
await mlIDbStorage.incrementIndexVersion("files");
}
log.info("syncLocalFiles", Date.now() - startTime, "ms");
}
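
Stripped of the IndexedDB plumbing, the reconciliation above is a plain two-way set difference; a standalone sketch:

// Two-way set difference between locally present file IDs and already
// indexed file IDs: new IDs get stub entries, stale IDs get removed.
const diffFileIDs = (localIDs: Set<number>, indexedIDs: Set<number>) => ({
    toAdd: [...localIDs].filter((id) => !indexedIDs.has(id)),
    toRemove: [...indexedIDs].filter((id) => !localIDs.has(id)),
});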
private async getOutOfSyncFiles(syncContext: MLSyncContext) {
const startTime = Date.now();
const fileIds = await mlIDbStorage.getFileIds(
batchSize,
defaultMLVersion,
MAX_ML_SYNC_ERROR_COUNT,
);
log.info("fileIds: ", JSON.stringify(fileIds));
const localFilesMap = await this.getLocalFilesMap(syncContext);
syncContext.outOfSyncFiles = fileIds.map((fileId) =>
localFilesMap.get(fileId),
);
log.info("getOutOfSyncFiles", Date.now() - startTime, "ms");
}
private async syncFiles(syncContext: MLSyncContext) {
this.isSyncing = true;
try {
const functions = syncContext.outOfSyncFiles.map(
(outOfSyncfile) => async () => {
@@ -212,12 +90,7 @@ class MachineLearningService {
syncContext.error = error;
}
await syncContext.syncQueue.onIdle();
// TODO: If the sync job ever needs to use multiple ML workers, do this
// in the same transaction as each file update, or keep it in the files
// store itself.
await mlIDbStorage.incrementIndexVersion("files");
// await this.disposeMLModels();
this.isSyncing = false;
}
private async getSyncContext(
@@ -300,23 +173,10 @@ class MachineLearningService {
localFile?: globalThis.File,
) {
try {
const mlFileData = await this.syncFile(
enteFile,
localFile,
syncContext.userAgent,
);
await this.syncFile(enteFile, localFile, syncContext.userAgent);
syncContext.nSyncedFiles += 1;
return mlFileData;
} catch (e) {
log.error("ML syncFile failed", e);
let error = e;
console.error(
"Error in ml sync, fileId: ",
enteFile.id,
"name: ",
enteFile.metadata.title,
error,
);
if ("status" in error) {
const parsedMessage = parseUploadErrorCodes(error);
error = parsedMessage;
@@ -331,42 +191,18 @@ class MachineLearningService {
throw error;
}
await this.persistMLFileSyncError(enteFile, error);
syncContext.nSyncedFiles += 1;
}
}
private async syncFile(
enteFile: EnteFile,
localFile: globalThis.File | undefined,
file: File | undefined,
userAgent: string,
) {
const oldMlFile = await mlIDbStorage.getFile(enteFile.id);
if (oldMlFile && oldMlFile.mlVersion) {
return oldMlFile;
}
const worker = new FaceIndexerWorker();
const newMlFile = await indexFaces(enteFile, localFile);
await putFaceEmbedding(enteFile, newMlFile, userAgent);
await mlIDbStorage.putFile(newMlFile);
return newMlFile;
}
private async persistMLFileSyncError(enteFile: EnteFile, e: Error) {
try {
await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => {
if (!mlFileData) {
mlFileData = this.newMlData(enteFile.id);
}
mlFileData.errorCount = (mlFileData.errorCount || 0) + 1;
console.error(`lastError for ${enteFile.id}`, e);
return mlFileData;
});
} catch (e) {
// TODO: logError, or stop the sync job once most of the requests have failed.
console.error("Error while storing ml sync error", e);
}
await worker.index(enteFile, file, userAgent);
}
}
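
Read with the removed lines elided, the post-change syncFile reduces to a delegation to the indexer worker. A reconstruction from the hunk above, as it would appear inside the class — a reading aid, not necessarily the exact final source:

// Reconstructed post-change shape: the local persistence and the
// version/error bookkeeping move behind the worker's index() call.
private async syncFile(
    enteFile: EnteFile,
    file: File | undefined,
    userAgent: string,
) {
    const worker = new FaceIndexerWorker();
    await worker.index(enteFile, file, userAgent);
}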

View File

@@ -8,7 +8,6 @@ import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
import debounce from "debounce";
import PQueue from "p-queue";
import { createFaceComlinkWorker } from "services/face";
import mlIDbStorage from "services/face/db-old";
import type { DedicatedMLWorker } from "services/face/face.worker";
import { EnteFile } from "types/file";
@@ -117,6 +116,10 @@ class MLWorkManager {
);
}
public isMlSearchEnabled() {
return this.mlSearchEnabled;
}
public async setMlSearchEnabled(enabled: boolean) {
if (!this.mlSearchEnabled && enabled) {
log.info("Enabling MLWorkManager");
@@ -163,7 +166,6 @@ class MLWorkManager {
this.stopSyncJob();
this.mlSyncJob = undefined;
await this.terminateLiveSyncWorker();
await mlIDbStorage.clearMLDB();
}
private async fileUploadedHandler(arg: {
@@ -224,7 +226,11 @@ class MLWorkManager {
this.mlSearchEnabled && this.startSyncJob();
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
public async syncLocalFile(enteFile: EnteFile, localFile: globalThis.File) {
return;
/*
TODO-ML(MR): Disable live sync for now
await this.liveSyncQueue.add(async () => {
this.stopSyncJob();
const token = getToken();
@@ -239,6 +245,7 @@ class MLWorkManager {
localFile,
);
});
*/
}
// Sync Job

View File

@@ -2,9 +2,7 @@ import { FILE_TYPE } from "@/media/file-type";
import log from "@/next/log";
import * as chrono from "chrono-node";
import { t } from "i18next";
import mlIDbStorage from "services/face/db-old";
import type { Person } from "services/face/people";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity";
import { EnteFile } from "types/file";
@@ -22,6 +20,7 @@ import { getFormattedDate } from "utils/search";
import { clipService, computeClipMatchScore } from "./clip-service";
import { localCLIPEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService";
import { faceIndexingStatus } from "./face/indexer";
import locationSearchService, { City } from "./locationSearchService";
const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
@@ -175,19 +174,24 @@ export async function getAllPeopleSuggestion(): Promise<Array<Suggestion>> {
export async function getIndexStatusSuggestion(): Promise<Suggestion> {
try {
const indexStatus = await mlIDbStorage.getIndexStatus(defaultMLVersion);
const indexStatus = await faceIndexingStatus();
let label;
if (!indexStatus.localFilesSynced) {
label = t("INDEXING_SCHEDULED");
} else if (indexStatus.outOfSyncFilesExists) {
label = t("ANALYZING_PHOTOS", {
indexStatus,
});
} else if (!indexStatus.peopleIndexSynced) {
label = t("INDEXING_PEOPLE", { indexStatus });
} else {
label = t("INDEXING_DONE", { indexStatus });
let label: string;
switch (indexStatus.phase) {
case "scheduled":
label = t("INDEXING_SCHEDULED");
break;
case "indexing":
label = t("ANALYZING_PHOTOS", { indexStatus });
break;
case "clustering":
label = t("INDEXING_PEOPLE", { indexStatus });
break;
case "done":
label = t("INDEXING_DONE", { indexStatus });
break;
}
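
The switch implies the shape of the new status value. A sketch of what `FaceIndexingStatus` presumably carries; the `phase` literals come from the cases above, while the remaining fields are assumptions based on the values interpolated into the labels:

// Presumed shape, inferred from the switch above; field names other
// than `phase` are assumptions, not the actual indexer definition.
type FaceIndexingStatusSketch = {
    phase: "scheduled" | "indexing" | "clustering" | "done";
    nSyncedFiles?: number; // assumed: feeds the "ANALYZING_PHOTOS" label
    nTotalFiles?: number; // assumed
};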
return {
@@ -430,7 +434,7 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
}
async function getAllPeople(limit: number = undefined) {
let people: Array<Person> = await mlIDbStorage.getAllPeople();
let people: Array<Person> = []; // await mlIDbStorage.getAllPeople();
// await mlPeopleStore.iterate<Person, void>((person) => {
// people.push(person);
// });

View File

@@ -1,5 +1,5 @@
import { FILE_TYPE } from "@/media/file-type";
import { IndexStatus } from "services/face/db-old";
import type { FaceIndexingStatus } from "services/face/indexer";
import type { Person } from "services/face/people";
import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity";
@@ -31,7 +31,7 @@ export interface Suggestion {
| DateValue
| number[]
| Person
| IndexStatus
| FaceIndexingStatus
| LocationTagData
| City
| FILE_TYPE
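
Since `Suggestion.value` is a wide union, call sites that previously cast to `IndexStatus` now need to narrow to the new member. A hypothetical type guard:

// Hypothetical narrowing helper; `phase` is the discriminating field of
// FaceIndexingStatus per the indexer-status switch shown earlier.
const isFaceIndexingStatus = (
    v: Suggestion["value"],
): v is FaceIndexingStatus =>
    typeof v === "object" && v !== null && "phase" in v;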

View File

@@ -81,6 +81,14 @@ class ModuleState {
const moduleState = new ModuleState();
/**
* @returns a string to use as an identifier when logging information about the
* given {@link enteFile}. The returned string contains the file name (for ease
* of debugging) and the file ID (for exactness).
*/
export const fileLogID = (enteFile: EnteFile) =>
`file ${enteFile.metadata.title ?? "-"} (${enteFile.id})`;
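
A hypothetical call site, to make the resulting format concrete:

// Hypothetical usage; for a file titled "IMG_1234.jpg" with ID 987654
// this logs: Indexing file IMG_1234.jpg (987654)
log.info(`Indexing ${fileLogID(enteFile)}`);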
export async function getUpdatedEXIFFileForDownload(
fileReader: FileReader,
file: EnteFile,