diff --git a/web/apps/photos/src/components/Search/SearchBar/searchInput/MenuWithPeople.tsx b/web/apps/photos/src/components/Search/SearchBar/searchInput/MenuWithPeople.tsx index 1ea50873c7..b9b7ea88d5 100644 --- a/web/apps/photos/src/components/Search/SearchBar/searchInput/MenuWithPeople.tsx +++ b/web/apps/photos/src/components/Search/SearchBar/searchInput/MenuWithPeople.tsx @@ -5,8 +5,8 @@ import { t } from "i18next"; import { AppContext } from "pages/_app"; import { useContext } from "react"; import { components } from "react-select"; +import { IndexStatus } from "services/ml/db"; import { Suggestion, SuggestionType } from "types/search"; -import { IndexStatus } from "utils/storage/mlIDbStorage"; const { Menu } = components; diff --git a/web/apps/photos/src/components/Sidebar/UtilitySection.tsx b/web/apps/photos/src/components/Sidebar/UtilitySection.tsx index 6b4a6f43d5..32f61d9764 100644 --- a/web/apps/photos/src/components/Sidebar/UtilitySection.tsx +++ b/web/apps/photos/src/components/Sidebar/UtilitySection.tsx @@ -9,7 +9,7 @@ import { t } from "i18next"; import { useRouter } from "next/router"; import { AppContext } from "pages/_app"; import { useContext, useState } from "react"; -// import mlIDbStorage from 'utils/storage/mlIDbStorage'; +// import mlIDbStorage from 'services/ml/db'; import { configurePasskeyRecovery, isPasskeyRecoveryEnabled, diff --git a/web/apps/photos/src/components/ml/PeopleList.tsx b/web/apps/photos/src/components/ml/PeopleList.tsx index 65a41b3552..4ff61e044d 100644 --- a/web/apps/photos/src/components/ml/PeopleList.tsx +++ b/web/apps/photos/src/components/ml/PeopleList.tsx @@ -3,9 +3,9 @@ import { Skeleton, styled } from "@mui/material"; import { Legend } from "components/PhotoViewer/styledComponents/Legend"; import { t } from "i18next"; import React, { useEffect, useState } from "react"; -import { Face, Person } from "services/ml/types"; +import mlIDbStorage from "services/ml/db"; +import { Face, Person, type MlFileData } from "services/ml/types"; import { EnteFile } from "types/file"; -import { getPeopleList, getUnidentifiedFaces } from "utils/machineLearning"; const FaceChipContainer = styled("div")` display: flex; @@ -194,3 +194,45 @@ const FaceCropImageView: React.FC = ({ ); }; + +async function getPeopleList(file: EnteFile): Promise> { + let startTime = Date.now(); + const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); + log.info( + "getPeopleList:mlFilesStore:getItem", + Date.now() - startTime, + "ms", + ); + if (!mlFileData?.faces || mlFileData.faces.length < 1) { + return []; + } + + const peopleIds = mlFileData.faces + .filter((f) => f.personId !== null && f.personId !== undefined) + .map((f) => f.personId); + if (!peopleIds || peopleIds.length < 1) { + return []; + } + // log.info("peopleIds: ", peopleIds); + startTime = Date.now(); + const peoplePromises = peopleIds.map( + (p) => mlIDbStorage.getPerson(p) as Promise, + ); + const peopleList = await Promise.all(peoplePromises); + log.info( + "getPeopleList:mlPeopleStore:getItems", + Date.now() - startTime, + "ms", + ); + // log.info("peopleList: ", peopleList); + + return peopleList; +} + +async function getUnidentifiedFaces(file: EnteFile): Promise> { + const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); + + return mlFileData?.faces?.filter( + (f) => f.personId === null || f.personId === undefined, + ); +} diff --git a/web/apps/photos/src/constants/mlConfig.ts b/web/apps/photos/src/constants/mlConfig.ts deleted file mode 100644 index 0ee2009420..0000000000 --- 
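The getPeopleList and getUnidentifiedFaces helpers inlined into PeopleList.tsx above are plain async lookups against the ML IndexedDB. A minimal usage sketch follows (the caller and its name are hypothetical, not part of this diff; note that getUnidentifiedFaces can resolve to undefined, hence the fallback):

```ts
// Hypothetical caller: gather both the named people and the not-yet-named
// faces for a file before rendering the face chips.
async function loadFacesForFile(file: EnteFile) {
    // People whose faces in this file carry a personId.
    const people = await getPeopleList(file);
    // Faces with no personId yet; the helper returns undefined when the
    // file has no indexed faces at all.
    const unidentified = (await getUnidentifiedFaces(file)) ?? [];
    return { people, unidentified };
}
```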
a/web/apps/photos/src/constants/mlConfig.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { MLSearchConfig, MLSyncConfig } from "services/ml/types"; -import { JobConfig } from "types/common/job"; - -export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = { - intervalSec: 5, - // TODO: finalize this after seeing effects on and from machine sleep - maxItervalSec: 960, - backoffMultiplier: 2, -}; - -export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = { - batchSize: 200, - imageSource: "Original", - faceDetection: { - method: "YoloFace", - }, - faceCrop: { - enabled: true, - method: "ArcFace", - padding: 0.25, - maxSize: 256, - blobOptions: { - type: "image/jpeg", - quality: 0.8, - }, - }, - faceAlignment: { - method: "ArcFace", - }, - blurDetection: { - method: "Laplacian", - threshold: 15, - }, - faceEmbedding: { - method: "MobileFaceNet", - faceSize: 112, - generateTsne: true, - }, - faceClustering: { - method: "Hdbscan", - minClusterSize: 3, - minSamples: 5, - clusterSelectionEpsilon: 0.6, - clusterSelectionMethod: "leaf", - minInputSize: 50, - // maxDistanceInsideCluster: 0.4, - generateDebugInfo: true, - }, - mlVersion: 3, -}; - -export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = { - enabled: false, -}; - -export const MAX_ML_SYNC_ERROR_COUNT = 1; diff --git a/web/apps/photos/src/pages/_app.tsx b/web/apps/photos/src/pages/_app.tsx index 0e80d0df9f..77e724d292 100644 --- a/web/apps/photos/src/pages/_app.tsx +++ b/web/apps/photos/src/pages/_app.tsx @@ -53,6 +53,10 @@ import { createContext, useEffect, useRef, useState } from "react"; import LoadingBar from "react-top-loading-bar"; import DownloadManager from "services/download"; import exportService, { resumeExportsIfNeeded } from "services/export"; +import { + getMLSearchConfig, + updateMLSearchConfig, +} from "services/machineLearning/machineLearningService"; import mlWorkManager from "services/machineLearning/mlWorkManager"; import { getFamilyPortalRedirectURL, @@ -64,10 +68,6 @@ import { NotificationAttributes, SetNotificationAttributes, } from "types/Notification"; -import { - getMLSearchConfig, - updateMLSearchConfig, -} from "utils/machineLearning/config"; import { getUpdateAvailableForDownloadMessage, getUpdateReadyToInstallMessage, diff --git a/web/apps/photos/src/pages/gallery/index.tsx b/web/apps/photos/src/pages/gallery/index.tsx index f90d1b8371..f870dfb768 100644 --- a/web/apps/photos/src/pages/gallery/index.tsx +++ b/web/apps/photos/src/pages/gallery/index.tsx @@ -84,7 +84,10 @@ import { getSectionSummaries, } from "services/collectionService"; import downloadManager from "services/download"; -import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService"; +import { + syncCLIPEmbeddings, + syncFaceEmbeddings, +} from "services/embeddingService"; import { syncEntities } from "services/entityService"; import { getLocalFiles, syncFiles } from "services/fileService"; import locationSearchService from "services/locationSearchService"; @@ -130,6 +133,7 @@ import { } from "utils/file"; import { isArchivedFile } from "utils/magicMetadata"; import { getSessionExpiredMessage } from "utils/ui"; +import { isInternalUserForML } from "utils/user"; import { getLocalFamilyData } from "utils/user/family"; export const DeadCenter = styled("div")` @@ -698,10 +702,10 @@ export default function Gallery() { await syncTrash(collections, setTrashedFiles); await syncEntities(); await syncMapEnabled(); - await syncEmbeddings(); + await syncCLIPEmbeddings(); const electron = globalThis.electron; - if (electron) { - await 
syncFileEmbeddings(); + if (isInternalUserForML() && electron) { + await syncFaceEmbeddings(); } if (clipService.isPlatformSupported()) { void clipService.scheduleImageEmbeddingExtraction(); diff --git a/web/apps/photos/src/services/clip-service.ts b/web/apps/photos/src/services/clip-service.ts index 703c89cf4b..aa724b4d58 100644 --- a/web/apps/photos/src/services/clip-service.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -11,7 +11,7 @@ import { Embedding } from "types/embedding"; import { EnteFile } from "types/file"; import { getPersonalFiles } from "utils/file"; import downloadManager from "./download"; -import { getLocalEmbeddings, putEmbedding } from "./embeddingService"; +import { localCLIPEmbeddings, putEmbedding } from "./embeddingService"; import { getAllLocalFiles, getLocalFiles } from "./fileService"; /** Status of CLIP indexing on the images in the user's local library. */ @@ -195,7 +195,7 @@ class CLIPService { return; } const localFiles = getPersonalFiles(await getAllLocalFiles(), user); - const existingEmbeddings = await getLocalEmbeddings(); + const existingEmbeddings = await localCLIPEmbeddings(); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, existingEmbeddings, @@ -394,7 +394,7 @@ export const computeClipMatchScore = async ( const initialIndexingStatus = async (): Promise => { const user = getData(LS_KEYS.USER); if (!user) throw new Error("Orphan CLIP indexing without a login"); - const allEmbeddings = await getLocalEmbeddings(); + const allEmbeddings = await localCLIPEmbeddings(); const localFiles = getPersonalFiles(await getLocalFiles(), user); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts index 36af848424..a60ae2a61b 100644 --- a/web/apps/photos/src/services/embeddingService.ts +++ b/web/apps/photos/src/services/embeddingService.ts @@ -7,6 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService"; import { getEndpoint } from "@ente/shared/network/api"; import localForage from "@ente/shared/storage/localForage"; import { getToken } from "@ente/shared/storage/localStorage/helpers"; +import { FileML } from "services/machineLearning/machineLearningService"; import type { Embedding, EmbeddingModel, @@ -15,31 +16,30 @@ import type { PutEmbeddingRequest, } from "types/embedding"; import { EnteFile } from "types/file"; -import { - getLatestVersionEmbeddings, - getLatestVersionFileEmbeddings, -} from "utils/embedding"; -import { FileML } from "utils/machineLearning/mldataMappers"; import { getLocalCollections } from "./collectionService"; import { getAllLocalFiles } from "./fileService"; import { getLocalTrashedFiles } from "./trashService"; -const ENDPOINT = getEndpoint(); - const DIFF_LIMIT = 500; -const EMBEDDINGS_TABLE_V1 = "embeddings"; -const EMBEDDINGS_TABLE = "embeddings_v2"; +/** Local storage key suffix for embedding sync times */ +const embeddingSyncTimeLSKeySuffix = "embedding_sync_time"; +/** Local storage key for CLIP embeddings. */ +const clipEmbeddingsLSKey = "embeddings_v2"; const FILE_EMBEDING_TABLE = "file_embeddings"; -const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time"; -export const getAllLocalEmbeddings = async () => { +/** Return all CLIP embeddings that we have available locally. 
*/ +export const localCLIPEmbeddings = async () => + (await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip"); + +const storedCLIPEmbeddings = async () => { const embeddings: Array<Embedding> = - await localForage.getItem(EMBEDDINGS_TABLE); + await localForage.getItem(clipEmbeddingsLSKey); if (!embeddings) { - await localForage.removeItem(EMBEDDINGS_TABLE_V1); - await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE); - await localForage.setItem(EMBEDDINGS_TABLE, []); + // Migrate + await localForage.removeItem("embeddings"); + await localForage.removeItem("embedding_sync_time"); + await localForage.setItem(clipEmbeddingsLSKey, []); return []; } return embeddings; @@ -54,15 +54,10 @@ export const getFileMLEmbeddings = async (): Promise<FileML[]> => { return embeddings; }; -export const getLocalEmbeddings = async () => { - const embeddings = await getAllLocalEmbeddings(); - return embeddings.filter((embedding) => embedding.model === "onnx-clip"); -}; - const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => { return ( (await localForage.getItem( - `${model}-${EMBEDDING_SYNC_TIME_TABLE}`, + `${model}-${embeddingSyncTimeLSKeySuffix}`, )) ?? 0 ); }; @@ -71,13 +66,17 @@ const setModelEmbeddingSyncTime = async ( model: EmbeddingModel, time: number, ) => { - await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time); + await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time); }; -export const syncEmbeddings = async () => { - const models: EmbeddingModel[] = ["onnx-clip"]; +/** + * Fetch new CLIP embeddings from the server and save them locally. Also prune + * local embeddings for files that no longer exist locally. + */ +export const syncCLIPEmbeddings = async () => { + const model: EmbeddingModel = "onnx-clip"; try { - let allEmbeddings = await getAllLocalEmbeddings(); + let allEmbeddings = await storedCLIPEmbeddings(); const localFiles = await getAllLocalFiles(); const hiddenAlbums = await getLocalCollections("hidden"); const localTrashFiles = await getLocalTrashedFiles(); @@ -89,79 +88,75 @@ export const syncEmbeddings = async () => { await cleanupDeletedEmbeddings( allLocalFiles, allEmbeddings, - EMBEDDINGS_TABLE, + clipEmbeddingsLSKey, ); log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`); - for (const model of models) { - let modelLastSinceTime = await getModelEmbeddingSyncTime(model); - log.info( - `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`, - ); - let response: GetEmbeddingDiffResponse; - do { - response = await getEmbeddingsDiff(modelLastSinceTime, model); - if (!response.diff?.length) { - return; - } - const newEmbeddings = await Promise.all( - response.diff.map(async (embedding) => { - try { - const { - encryptedEmbedding, - decryptionHeader, - ...rest - } = embedding; - const worker = - await ComlinkCryptoWorker.getInstance(); - const fileKey = fileIdToKeyMap.get( - embedding.fileID, - ); - if (!fileKey) { - throw Error(CustomError.FILE_NOT_FOUND); - } - const decryptedData = await worker.decryptEmbedding( - encryptedEmbedding, - decryptionHeader, - fileIdToKeyMap.get(embedding.fileID), - ); - return { - ...rest, - embedding: decryptedData, - } as Embedding; - } catch (e) { - let hasHiddenAlbums = false; - if (e.message === CustomError.FILE_NOT_FOUND) { - hasHiddenAlbums = hiddenAlbums?.length > 0; - } - log.error( - `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`, - e, - ); + let modelLastSinceTime = await getModelEmbeddingSyncTime(model); + log.info( + `Syncing ${model} 
model's embeddings sinceTime: ${modelLastSinceTime}`, + ); + let response: GetEmbeddingDiffResponse; + do { + response = await getEmbeddingsDiff(modelLastSinceTime, model); + if (!response.diff?.length) { + return; + } + const newEmbeddings = await Promise.all( + response.diff.map(async (embedding) => { + try { + const { + encryptedEmbedding, + decryptionHeader, + ...rest + } = embedding; + const worker = await ComlinkCryptoWorker.getInstance(); + const fileKey = fileIdToKeyMap.get(embedding.fileID); + if (!fileKey) { + throw Error(CustomError.FILE_NOT_FOUND); } - }), - ); - allEmbeddings = getLatestVersionEmbeddings([ - ...allEmbeddings, - ...newEmbeddings, - ]); - if (response.diff.length) { - modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; - } - await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings); - await setModelEmbeddingSyncTime(model, modelLastSinceTime); - log.info( - `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`, - ); - } while (response.diff.length === DIFF_LIMIT); - } + const decryptedData = await worker.decryptEmbedding( + encryptedEmbedding, + decryptionHeader, + fileIdToKeyMap.get(embedding.fileID), + ); + + return { + ...rest, + embedding: decryptedData, + } as Embedding; + } catch (e) { + let hasHiddenAlbums = false; + if (e.message === CustomError.FILE_NOT_FOUND) { + hasHiddenAlbums = hiddenAlbums?.length > 0; + } + log.error( + `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`, + e, + ); + } + }), + ); + allEmbeddings = getLatestVersionEmbeddings([ + ...allEmbeddings, + ...newEmbeddings, + ]); + if (response.diff.length) { + modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; + } + await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings); + await setModelEmbeddingSyncTime(model, modelLastSinceTime); + log.info( + `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`, + ); + } while (response.diff.length === DIFF_LIMIT); } catch (e) { log.error("Sync embeddings failed", e); } }; -export const syncFileEmbeddings = async () => { - const models: EmbeddingModel[] = ["file-ml-clip-face"]; +export const syncFaceEmbeddings = async () => { + const model: EmbeddingModel = "file-ml-clip-face"; try { let allEmbeddings: FileML[] = await getFileMLEmbeddings(); const localFiles = await getAllLocalFiles(); @@ -178,69 +173,99 @@ export const syncFileEmbeddings = async () => { FILE_EMBEDING_TABLE, ); log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`); - for (const model of models) { - let modelLastSinceTime = await getModelEmbeddingSyncTime(model); - log.info( - `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`, - ); - let response: GetEmbeddingDiffResponse; - do { - response = await getEmbeddingsDiff(modelLastSinceTime, model); - if (!response.diff?.length) { - return; - } - const newEmbeddings = await Promise.all( - response.diff.map(async (embedding) => { - try { - const worker = - await ComlinkCryptoWorker.getInstance(); - const fileKey = fileIdToKeyMap.get( - embedding.fileID, - ); - if (!fileKey) { - throw Error(CustomError.FILE_NOT_FOUND); - } - const decryptedData = await worker.decryptMetadata( - embedding.encryptedEmbedding, - embedding.decryptionHeader, - fileIdToKeyMap.get(embedding.fileID), - ); - return { - ...decryptedData, - updatedAt: embedding.updatedAt, - } as unknown as FileML; - } catch (e) { - let hasHiddenAlbums = false; - if (e.message === CustomError.FILE_NOT_FOUND) { - hasHiddenAlbums = hiddenAlbums?.length > 0; - } - log.error( - 
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`, - e, - ); + let modelLastSinceTime = await getModelEmbeddingSyncTime(model); + log.info( + `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`, + ); + let response: GetEmbeddingDiffResponse; + do { + response = await getEmbeddingsDiff(modelLastSinceTime, model); + if (!response.diff?.length) { + return; + } + const newEmbeddings = await Promise.all( + response.diff.map(async (embedding) => { + try { + const worker = await ComlinkCryptoWorker.getInstance(); + const fileKey = fileIdToKeyMap.get(embedding.fileID); + if (!fileKey) { + throw Error(CustomError.FILE_NOT_FOUND); } - }), - ); - allEmbeddings = getLatestVersionFileEmbeddings([ - ...allEmbeddings, - ...newEmbeddings, - ]); - if (response.diff.length) { - modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; - } - await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings); - await setModelEmbeddingSyncTime(model, modelLastSinceTime); - log.info( - `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`, - ); - } while (response.diff.length === DIFF_LIMIT); - } + const decryptedData = await worker.decryptMetadata( + embedding.encryptedEmbedding, + embedding.decryptionHeader, + fileIdToKeyMap.get(embedding.fileID), + ); + + return { + ...decryptedData, + updatedAt: embedding.updatedAt, + } as unknown as FileML; + } catch (e) { + let hasHiddenAlbums = false; + if (e.message === CustomError.FILE_NOT_FOUND) { + hasHiddenAlbums = hiddenAlbums?.length > 0; + } + log.error( + `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`, + e, + ); + } + }), + ); + allEmbeddings = getLatestVersionFileEmbeddings([ + ...allEmbeddings, + ...newEmbeddings, + ]); + if (response.diff.length) { + modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; + } + await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings); + await setModelEmbeddingSyncTime(model, modelLastSinceTime); + log.info( + `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`, + ); + } while (response.diff.length === DIFF_LIMIT); } catch (e) { log.error("Sync embeddings failed", e); } }; +const getLatestVersionEmbeddings = (embeddings: Embedding[]) => { + const latestVersionEntities = new Map(); + embeddings.forEach((embedding) => { + if (!embedding?.fileID) { + return; + } + const existingEmbeddings = latestVersionEntities.get(embedding.fileID); + if ( + !existingEmbeddings || + existingEmbeddings.updatedAt < embedding.updatedAt + ) { + latestVersionEntities.set(embedding.fileID, embedding); + } + }); + return Array.from(latestVersionEntities.values()); +}; + +const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => { + const latestVersionEntities = new Map(); + embeddings.forEach((embedding) => { + if (!embedding?.fileID) { + return; + } + const existingEmbeddings = latestVersionEntities.get(embedding.fileID); + if ( + !existingEmbeddings || + existingEmbeddings.updatedAt < embedding.updatedAt + ) { + latestVersionEntities.set(embedding.fileID, embedding); + } + }); + return Array.from(latestVersionEntities.values()); +}; + export const getEmbeddingsDiff = async ( sinceTime: number, model: EmbeddingModel, @@ -251,7 +276,7 @@ export const getEmbeddingsDiff = async ( return; } const response = await HTTPService.get( - `${ENDPOINT}/embeddings/diff`, + `${getEndpoint()}/embeddings/diff`, { sinceTime, limit: DIFF_LIMIT, @@ -280,7 +305,7 @@ export const putEmbedding = async ( throw Error(CustomError.TOKEN_MISSING); } const resp = await 
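Both sync functions in this file now follow the same cursor-based pull loop: fetch a page of changes since the stored timestamp, merge and persist, advance the cursor to the last entry's updatedAt, and stop once a short page arrives. A generic sketch of that pattern, with hypothetical stand-in parameters rather than APIs from this codebase:

```ts
// Generic shape of the diff-pull loop used by syncCLIPEmbeddings and
// syncFaceEmbeddings: a page equal to the limit means there may be more.
async function pullDiffPages<T extends { updatedAt: number }>(
    pageLimit: number,
    getCursor: () => Promise<number>,
    setCursor: (t: number) => Promise<void>,
    fetchSince: (sinceTime: number) => Promise<T[]>,
    mergeAndPersist: (page: T[]) => Promise<void>,
) {
    let sinceTime = await getCursor();
    let page: T[];
    do {
        page = await fetchSince(sinceTime);
        if (page.length === 0) return;
        await mergeAndPersist(page);
        // Entries arrive ordered by updatedAt (the originals rely on the
        // same assumption), so the last entry of the page is the cursor.
        sinceTime = page[page.length - 1].updatedAt;
        await setCursor(sinceTime);
    } while (page.length === pageLimit);
}
```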
HTTPService.put( - `${ENDPOINT}/embeddings`, + `${getEndpoint()}/embeddings`, putEmbeddingReq, null, { diff --git a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts b/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts index 1d7deac5e2..f23a065c83 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts +++ b/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts @@ -1,3 +1,5 @@ +import { Matrix } from "ml-matrix"; +import { Point } from "services/ml/geom"; import { FaceAlignment, FaceAlignmentMethod, @@ -5,7 +7,7 @@ import { FaceDetection, Versioned, } from "services/ml/types"; -import { getArcfaceAlignment } from "utils/machineLearning/faceAlign"; +import { getSimilarityTransformation } from "similarity-transformation"; class ArcfaceAlignmentService implements FaceAlignmentService { public method: Versioned; @@ -23,3 +25,86 @@ class ArcfaceAlignmentService implements FaceAlignmentService { } export default new ArcfaceAlignmentService(); + +const ARCFACE_LANDMARKS = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [56.1396, 92.2848], +] as Array<[number, number]>; + +const ARCFACE_LANDMARKS_FACE_SIZE = 112; + +const ARC_FACE_5_LANDMARKS = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], +] as Array<[number, number]>; + +export function getArcfaceAlignment( + faceDetection: FaceDetection, +): FaceAlignment { + const landmarkCount = faceDetection.landmarks.length; + return getFaceAlignmentUsingSimilarityTransform( + faceDetection, + normalizeLandmarks( + landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS, + ARCFACE_LANDMARKS_FACE_SIZE, + ), + ); +} + +function getFaceAlignmentUsingSimilarityTransform( + faceDetection: FaceDetection, + alignedLandmarks: Array<[number, number]>, + // alignmentMethod: Versioned +): FaceAlignment { + const landmarksMat = new Matrix( + faceDetection.landmarks + .map((p) => [p.x, p.y]) + .slice(0, alignedLandmarks.length), + ).transpose(); + const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose(); + + const simTransform = getSimilarityTransformation( + landmarksMat, + alignedLandmarksMat, + ); + + const RS = Matrix.mul(simTransform.rotation, simTransform.scale); + const TR = simTransform.translation; + + const affineMatrix = [ + [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)], + [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)], + [0, 0, 1], + ]; + + const size = 1 / simTransform.scale; + const meanTranslation = simTransform.toMean.sub(0.5).mul(size); + const centerMat = simTransform.fromMean.sub(meanTranslation); + const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0)); + const rotation = -Math.atan2( + simTransform.rotation.get(0, 1), + simTransform.rotation.get(0, 0), + ); + // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size }); + + return { + affineMatrix, + center, + size, + rotation, + }; +} + +function normalizeLandmarks( + landmarks: Array<[number, number]>, + faceSize: number, +): Array<[number, number]> { + return landmarks.map((landmark) => + landmark.map((p) => p / faceSize), + ) as Array<[number, number]>; +} diff --git a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts index f5e1e835a6..2075d6acf8 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts +++ 
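In getFaceAlignmentUsingSimilarityTransform above, the rotation angle is recovered as -atan2(R[0][1], R[0][0]); since a 2D rotation matrix has R[0][1] = -sin θ and R[0][0] = cos θ, that expression is exactly θ. A small sanity sketch using the ml-matrix import already present in this file:

```ts
import { Matrix } from "ml-matrix";

// Verify the angle-recovery identity used above on a pure 30° rotation.
const theta = Math.PI / 6;
const R = new Matrix([
    [Math.cos(theta), -Math.sin(theta)],
    [Math.sin(theta), Math.cos(theta)],
]);
// -atan2(-sin θ, cos θ) === θ
const recovered = -Math.atan2(R.get(0, 1), R.get(0, 0));
console.log(recovered.toFixed(6), theta.toFixed(6)); // both 0.523599
```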
b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts @@ -1,4 +1,6 @@ +import { Box, enlargeBox } from "services/ml/geom"; import { + FaceAlignment, FaceCrop, FaceCropConfig, FaceCropMethod, @@ -6,8 +8,8 @@ import { FaceDetection, Versioned, } from "services/ml/types"; -import { getArcfaceAlignment } from "utils/machineLearning/faceAlign"; -import { getFaceCrop } from "utils/machineLearning/faceCrop"; +import { cropWithRotation } from "utils/image"; +import { getArcfaceAlignment } from "./arcfaceAlignmentService"; class ArcFaceCropService implements FaceCropService { public method: Versioned; @@ -32,3 +34,27 @@ class ArcFaceCropService implements FaceCropService { } export default new ArcFaceCropService(); + +export function getFaceCrop( + imageBitmap: ImageBitmap, + alignment: FaceAlignment, + config: FaceCropConfig, +): FaceCrop { + const alignmentBox = new Box({ + x: alignment.center.x - alignment.size / 2, + y: alignment.center.y - alignment.size / 2, + width: alignment.size, + height: alignment.size, + }).round(); + const scaleForPadding = 1 + config.padding * 2; + const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); + const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, { + width: config.maxSize, + height: config.maxSize, + }); + + return { + image: faceImageBitmap, + imageBox: paddedBox, + }; +} diff --git a/web/apps/photos/src/services/machineLearning/faceService.ts b/web/apps/photos/src/services/machineLearning/faceService.ts index 6bb045ae2b..b7805b3360 100644 --- a/web/apps/photos/src/services/machineLearning/faceService.ts +++ b/web/apps/photos/src/services/machineLearning/faceService.ts @@ -1,22 +1,20 @@ import { openCache } from "@/next/blob-cache"; import log from "@/next/log"; +import mlIDbStorage from "services/ml/db"; import { DetectedFace, Face, MLSyncContext, MLSyncFileContext, + type FaceAlignment, + type Versioned, } from "services/ml/types"; -import { imageBitmapToBlob } from "utils/image"; -import { - areFaceIdsSame, - extractFaceImagesToFloat32, +import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image"; +import ReaderService, { getFaceId, getLocalFile, getOriginalImageBitmap, - isDifferentOrOld, -} from "utils/machineLearning"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; -import ReaderService from "./readerService"; +} from "./readerService"; class FaceService { async syncFileFaceDetections( @@ -304,3 +302,58 @@ class FaceService { } export default new FaceService(); + +export function areFaceIdsSame(ofFaces: Array, toFaces: Array) { + if ( + (ofFaces === null || ofFaces === undefined) && + (toFaces === null || toFaces === undefined) + ) { + return true; + } + return primitiveArrayEquals( + ofFaces?.map((f) => f.id), + toFaces?.map((f) => f.id), + ); +} + +function primitiveArrayEquals(a, b) { + return ( + Array.isArray(a) && + Array.isArray(b) && + a.length === b.length && + a.every((val, index) => val === b[index]) + ); +} + +export function isDifferentOrOld( + method: Versioned, + thanMethod: Versioned, +) { + return ( + !method || + method.value !== thanMethod.value || + method.version < thanMethod.version + ); +} + +async function extractFaceImagesToFloat32( + faceAlignments: Array, + faceSize: number, + image: ImageBitmap, +): Promise { + const faceData = new Float32Array( + faceAlignments.length * faceSize * faceSize * 3, + ); + for (let i = 0; i < faceAlignments.length; i++) { + const alignedFace = faceAlignments[i]; + const faceDataOffset = i * faceSize * faceSize * 3; + 
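getFaceCrop above grows the alignment square by a factor of 1 + 2·padding before cropping, so the default padding of 0.25 yields a crop box 1.5x the aligned face, which cropWithRotation then scales to at most maxSize. A worked example with the Box and enlargeBox helpers from services/ml/geom:

```ts
import { Box, enlargeBox } from "services/ml/geom";

// With the default faceCrop config (padding: 0.25), the scale factor is
// 1 + 0.25 * 2 = 1.5, applied about the box's center.
const alignmentBox = new Box({ x: 50, y: 50, width: 100, height: 100 });
const padded = enlargeBox(alignmentBox, 1 + 0.25 * 2).round();
// padded is { x: 25, y: 25, width: 150, height: 150 }: 25px of context on
// every side of the original 100px face square.
```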
warpAffineFloat32List( + image, + alignedFace, + faceSize, + faceData, + faceDataOffset, + ); + } + return faceData; +} diff --git a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts b/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts deleted file mode 100644 index 999bae8fe9..0000000000 --- a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts +++ /dev/null @@ -1,216 +0,0 @@ -import { haveWindow } from "@/next/env"; -import log from "@/next/log"; -import { ComlinkWorker } from "@/next/worker/comlink-worker"; -import { getDedicatedCryptoWorker } from "@ente/shared/crypto"; -import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker"; -import PQueue from "p-queue"; -import { - BlurDetectionMethod, - BlurDetectionService, - ClusteringMethod, - ClusteringService, - Face, - FaceAlignmentMethod, - FaceAlignmentService, - FaceCropMethod, - FaceCropService, - FaceDetectionMethod, - FaceDetectionService, - FaceEmbeddingMethod, - FaceEmbeddingService, - MLLibraryData, - MLSyncConfig, - MLSyncContext, -} from "services/ml/types"; -import { EnteFile } from "types/file"; -import { logQueueStats } from "utils/machineLearning"; -import arcfaceAlignmentService from "./arcfaceAlignmentService"; -import arcfaceCropService from "./arcfaceCropService"; -import dbscanClusteringService from "./dbscanClusteringService"; -import hdbscanClusteringService from "./hdbscanClusteringService"; -import laplacianBlurDetectionService from "./laplacianBlurDetectionService"; -import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService"; -import yoloFaceDetectionService from "./yoloFaceDetectionService"; - -export class MLFactory { - public static getFaceDetectionService( - method: FaceDetectionMethod, - ): FaceDetectionService { - if (method === "YoloFace") { - return yoloFaceDetectionService; - } - - throw Error("Unknon face detection method: " + method); - } - - public static getFaceCropService(method: FaceCropMethod) { - if (method === "ArcFace") { - return arcfaceCropService; - } - - throw Error("Unknon face crop method: " + method); - } - - public static getFaceAlignmentService( - method: FaceAlignmentMethod, - ): FaceAlignmentService { - if (method === "ArcFace") { - return arcfaceAlignmentService; - } - - throw Error("Unknon face alignment method: " + method); - } - - public static getBlurDetectionService( - method: BlurDetectionMethod, - ): BlurDetectionService { - if (method === "Laplacian") { - return laplacianBlurDetectionService; - } - - throw Error("Unknon blur detection method: " + method); - } - - public static getFaceEmbeddingService( - method: FaceEmbeddingMethod, - ): FaceEmbeddingService { - if (method === "MobileFaceNet") { - return mobileFaceNetEmbeddingService; - } - - throw Error("Unknon face embedding method: " + method); - } - - public static getClusteringService( - method: ClusteringMethod, - ): ClusteringService { - if (method === "Hdbscan") { - return hdbscanClusteringService; - } - if (method === "Dbscan") { - return dbscanClusteringService; - } - - throw Error("Unknon clustering method: " + method); - } - - public static getMLSyncContext( - token: string, - userID: number, - config: MLSyncConfig, - shouldUpdateMLVersion: boolean = true, - ) { - return new LocalMLSyncContext( - token, - userID, - config, - shouldUpdateMLVersion, - ); - } -} - -export class LocalMLSyncContext implements MLSyncContext { - public token: string; - public userID: number; - public config: MLSyncConfig; - public 
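extractFaceImagesToFloat32 above packs all aligned faces of a file into one preallocated buffer rather than allocating per face; each face occupies a fixed faceSize × faceSize × 3 span, so the i-th face starts at a simple multiple. The arithmetic, spelled out:

```ts
// Buffer layout used by extractFaceImagesToFloat32: N faces back to back,
// each taking faceSize * faceSize * 3 floats (3 channels per pixel).
const faceSize = 112; // MobileFaceNet input size from the default config
const floatsPerFace = faceSize * faceSize * 3; // 37632
const offsetOfFace = (i: number) => i * floatsPerFace;
const buffer = new Float32Array(3 * floatsPerFace); // room for 3 faces
console.log(offsetOfFace(2)); // 75264: where the third face begins
```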
shouldUpdateMLVersion: boolean; - - public faceDetectionService: FaceDetectionService; - public faceCropService: FaceCropService; - public faceAlignmentService: FaceAlignmentService; - public blurDetectionService: BlurDetectionService; - public faceEmbeddingService: FaceEmbeddingService; - public faceClusteringService: ClusteringService; - - public localFilesMap: Map; - public outOfSyncFiles: EnteFile[]; - public nSyncedFiles: number; - public nSyncedFaces: number; - public allSyncedFacesMap?: Map>; - - public error?: Error; - - public mlLibraryData: MLLibraryData; - - public syncQueue: PQueue; - // TODO: wheather to limit concurrent downloads - // private downloadQueue: PQueue; - - private concurrency: number; - private comlinkCryptoWorker: Array< - ComlinkWorker - >; - private enteWorkers: Array; - - constructor( - token: string, - userID: number, - config: MLSyncConfig, - shouldUpdateMLVersion: boolean = true, - concurrency?: number, - ) { - this.token = token; - this.userID = userID; - this.config = config; - this.shouldUpdateMLVersion = shouldUpdateMLVersion; - - this.faceDetectionService = MLFactory.getFaceDetectionService( - this.config.faceDetection.method, - ); - this.faceCropService = MLFactory.getFaceCropService( - this.config.faceCrop.method, - ); - this.faceAlignmentService = MLFactory.getFaceAlignmentService( - this.config.faceAlignment.method, - ); - this.blurDetectionService = MLFactory.getBlurDetectionService( - this.config.blurDetection.method, - ); - this.faceEmbeddingService = MLFactory.getFaceEmbeddingService( - this.config.faceEmbedding.method, - ); - this.faceClusteringService = MLFactory.getClusteringService( - this.config.faceClustering.method, - ); - - this.outOfSyncFiles = []; - this.nSyncedFiles = 0; - this.nSyncedFaces = 0; - - this.concurrency = concurrency ?? 
getConcurrency(); - - log.info("Using concurrency: ", this.concurrency); - // timeout is added on downloads - // timeout on queue will keep the operation open till worker is terminated - this.syncQueue = new PQueue({ concurrency: this.concurrency }); - logQueueStats(this.syncQueue, "sync"); - // this.downloadQueue = new PQueue({ concurrency: 1 }); - // logQueueStats(this.downloadQueue, 'download'); - - this.comlinkCryptoWorker = new Array(this.concurrency); - this.enteWorkers = new Array(this.concurrency); - } - - public async getEnteWorker(id: number): Promise { - const wid = id % this.enteWorkers.length; - console.log("getEnteWorker: ", id, wid); - if (!this.enteWorkers[wid]) { - this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker(); - this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote; - } - - return this.enteWorkers[wid]; - } - - public async dispose() { - this.localFilesMap = undefined; - await this.syncQueue.onIdle(); - this.syncQueue.removeAllListeners(); - for (const enteComlinkWorker of this.comlinkCryptoWorker) { - enteComlinkWorker?.terminate(); - } - } -} - -export const getConcurrency = () => - haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); diff --git a/web/apps/photos/src/services/machineLearning/machineLearningService.ts b/web/apps/photos/src/services/machineLearning/machineLearningService.ts index 48edb0490e..4ac17dbb8c 100644 --- a/web/apps/photos/src/services/machineLearning/machineLearningService.ts +++ b/web/apps/photos/src/services/machineLearning/machineLearningService.ts @@ -1,25 +1,330 @@ +import { haveWindow } from "@/next/env"; import log from "@/next/log"; +import { ComlinkWorker } from "@/next/worker/comlink-worker"; import { APPS } from "@ente/shared/apps/constants"; -import ComlinkCryptoWorker from "@ente/shared/crypto"; +import ComlinkCryptoWorker, { + getDedicatedCryptoWorker, +} from "@ente/shared/crypto"; +import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker"; import { CustomError, parseUploadErrorCodes } from "@ente/shared/error"; -import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig"; +import PQueue from "p-queue"; import downloadManager from "services/download"; import { putEmbedding } from "services/embeddingService"; import { getLocalFiles } from "services/fileService"; +import mlIDbStorage, { + ML_SEARCH_CONFIG_NAME, + ML_SYNC_CONFIG_NAME, + ML_SYNC_JOB_CONFIG_NAME, +} from "services/ml/db"; import { + BlurDetectionMethod, + BlurDetectionService, + ClusteringMethod, + ClusteringService, + Face, + FaceAlignmentMethod, + FaceAlignmentService, + FaceCropMethod, + FaceCropService, + FaceDetection, + FaceDetectionMethod, + FaceDetectionService, + FaceEmbeddingMethod, + FaceEmbeddingService, + Landmark, + MLLibraryData, + MLSearchConfig, + MLSyncConfig, MLSyncContext, MLSyncFileContext, MLSyncResult, MlFileData, } from "services/ml/types"; import { EnteFile } from "types/file"; -import { getMLSyncConfig } from "utils/machineLearning/config"; -import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; +import { isInternalUserForML } from "utils/user"; +import arcfaceAlignmentService from "./arcfaceAlignmentService"; +import arcfaceCropService from "./arcfaceCropService"; +import dbscanClusteringService from "./dbscanClusteringService"; import FaceService from "./faceService"; -import { MLFactory } from "./machineLearningFactory"; +import hdbscanClusteringService from 
"./hdbscanClusteringService"; +import laplacianBlurDetectionService from "./laplacianBlurDetectionService"; +import type { JobConfig } from "./mlWorkManager"; +import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService"; import PeopleService from "./peopleService"; import ReaderService from "./readerService"; +import yoloFaceDetectionService from "./yoloFaceDetectionService"; + +export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = { + intervalSec: 5, + // TODO: finalize this after seeing effects on and from machine sleep + maxItervalSec: 960, + backoffMultiplier: 2, +}; + +export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = { + batchSize: 200, + imageSource: "Original", + faceDetection: { + method: "YoloFace", + }, + faceCrop: { + enabled: true, + method: "ArcFace", + padding: 0.25, + maxSize: 256, + blobOptions: { + type: "image/jpeg", + quality: 0.8, + }, + }, + faceAlignment: { + method: "ArcFace", + }, + blurDetection: { + method: "Laplacian", + threshold: 15, + }, + faceEmbedding: { + method: "MobileFaceNet", + faceSize: 112, + generateTsne: true, + }, + faceClustering: { + method: "Hdbscan", + minClusterSize: 3, + minSamples: 5, + clusterSelectionEpsilon: 0.6, + clusterSelectionMethod: "leaf", + minInputSize: 50, + // maxDistanceInsideCluster: 0.4, + generateDebugInfo: true, + }, + mlVersion: 3, +}; + +export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = { + enabled: false, +}; + +export const MAX_ML_SYNC_ERROR_COUNT = 1; + +export async function getMLSyncJobConfig() { + return mlIDbStorage.getConfig( + ML_SYNC_JOB_CONFIG_NAME, + DEFAULT_ML_SYNC_JOB_CONFIG, + ); +} + +export async function getMLSyncConfig() { + return mlIDbStorage.getConfig(ML_SYNC_CONFIG_NAME, DEFAULT_ML_SYNC_CONFIG); +} + +export async function getMLSearchConfig() { + if (isInternalUserForML()) { + return mlIDbStorage.getConfig( + ML_SEARCH_CONFIG_NAME, + DEFAULT_ML_SEARCH_CONFIG, + ); + } + // Force disabled for everyone else while we finalize it to avoid redundant + // reindexing for users. 
+ return DEFAULT_ML_SEARCH_CONFIG; +} + +export async function updateMLSyncJobConfig(newConfig: JobConfig) { + return mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig); +} + +export async function updateMLSyncConfig(newConfig: MLSyncConfig) { + return mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig); +} + +export async function updateMLSearchConfig(newConfig: MLSearchConfig) { + return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig); +} + +export class MLFactory { + public static getFaceDetectionService( + method: FaceDetectionMethod, + ): FaceDetectionService { + if (method === "YoloFace") { + return yoloFaceDetectionService; + } + + throw Error("Unknon face detection method: " + method); + } + + public static getFaceCropService(method: FaceCropMethod) { + if (method === "ArcFace") { + return arcfaceCropService; + } + + throw Error("Unknon face crop method: " + method); + } + + public static getFaceAlignmentService( + method: FaceAlignmentMethod, + ): FaceAlignmentService { + if (method === "ArcFace") { + return arcfaceAlignmentService; + } + + throw Error("Unknon face alignment method: " + method); + } + + public static getBlurDetectionService( + method: BlurDetectionMethod, + ): BlurDetectionService { + if (method === "Laplacian") { + return laplacianBlurDetectionService; + } + + throw Error("Unknon blur detection method: " + method); + } + + public static getFaceEmbeddingService( + method: FaceEmbeddingMethod, + ): FaceEmbeddingService { + if (method === "MobileFaceNet") { + return mobileFaceNetEmbeddingService; + } + + throw Error("Unknon face embedding method: " + method); + } + + public static getClusteringService( + method: ClusteringMethod, + ): ClusteringService { + if (method === "Hdbscan") { + return hdbscanClusteringService; + } + if (method === "Dbscan") { + return dbscanClusteringService; + } + + throw Error("Unknon clustering method: " + method); + } + + public static getMLSyncContext( + token: string, + userID: number, + config: MLSyncConfig, + shouldUpdateMLVersion: boolean = true, + ) { + return new LocalMLSyncContext( + token, + userID, + config, + shouldUpdateMLVersion, + ); + } +} + +export class LocalMLSyncContext implements MLSyncContext { + public token: string; + public userID: number; + public config: MLSyncConfig; + public shouldUpdateMLVersion: boolean; + + public faceDetectionService: FaceDetectionService; + public faceCropService: FaceCropService; + public faceAlignmentService: FaceAlignmentService; + public blurDetectionService: BlurDetectionService; + public faceEmbeddingService: FaceEmbeddingService; + public faceClusteringService: ClusteringService; + + public localFilesMap: Map; + public outOfSyncFiles: EnteFile[]; + public nSyncedFiles: number; + public nSyncedFaces: number; + public allSyncedFacesMap?: Map>; + + public error?: Error; + + public mlLibraryData: MLLibraryData; + + public syncQueue: PQueue; + // TODO: wheather to limit concurrent downloads + // private downloadQueue: PQueue; + + private concurrency: number; + private comlinkCryptoWorker: Array< + ComlinkWorker + >; + private enteWorkers: Array; + + constructor( + token: string, + userID: number, + config: MLSyncConfig, + shouldUpdateMLVersion: boolean = true, + concurrency?: number, + ) { + this.token = token; + this.userID = userID; + this.config = config; + this.shouldUpdateMLVersion = shouldUpdateMLVersion; + + this.faceDetectionService = MLFactory.getFaceDetectionService( + this.config.faceDetection.method, + ); + this.faceCropService = 
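The config accessors above (getMLSyncJobConfig, getMLSyncConfig, getMLSearchConfig and the corresponding updateML*Config writers) all reduce to mlIDbStorage.getConfig/putConfig with a default. A round-trip usage sketch; the enclosing function is hypothetical, not part of this diff:

```ts
// Hypothetical settings handler: read the current search config (falling
// back to DEFAULT_ML_SEARCH_CONFIG), flip the flag, and persist it.
async function setMLSearchEnabled(enabled: boolean) {
    const config = await getMLSearchConfig();
    await updateMLSearchConfig({ ...config, enabled });
}
```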
MLFactory.getFaceCropService( + this.config.faceCrop.method, + ); + this.faceAlignmentService = MLFactory.getFaceAlignmentService( + this.config.faceAlignment.method, + ); + this.blurDetectionService = MLFactory.getBlurDetectionService( + this.config.blurDetection.method, + ); + this.faceEmbeddingService = MLFactory.getFaceEmbeddingService( + this.config.faceEmbedding.method, + ); + this.faceClusteringService = MLFactory.getClusteringService( + this.config.faceClustering.method, + ); + + this.outOfSyncFiles = []; + this.nSyncedFiles = 0; + this.nSyncedFaces = 0; + + this.concurrency = concurrency ?? getConcurrency(); + + log.info("Using concurrency: ", this.concurrency); + // timeout is added on downloads + // timeout on queue will keep the operation open till worker is terminated + this.syncQueue = new PQueue({ concurrency: this.concurrency }); + logQueueStats(this.syncQueue, "sync"); + // this.downloadQueue = new PQueue({ concurrency: 1 }); + // logQueueStats(this.downloadQueue, 'download'); + + this.comlinkCryptoWorker = new Array(this.concurrency); + this.enteWorkers = new Array(this.concurrency); + } + + public async getEnteWorker(id: number): Promise { + const wid = id % this.enteWorkers.length; + console.log("getEnteWorker: ", id, wid); + if (!this.enteWorkers[wid]) { + this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker(); + this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote; + } + + return this.enteWorkers[wid]; + } + + public async dispose() { + this.localFilesMap = undefined; + await this.syncQueue.onIdle(); + this.syncQueue.removeAllListeners(); + for (const enteComlinkWorker of this.comlinkCryptoWorker) { + enteComlinkWorker?.terminate(); + } + } +} + +export const getConcurrency = () => + haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); class MachineLearningService { private localSyncContext: Promise; @@ -445,3 +750,160 @@ class MachineLearningService { } export default new MachineLearningService(); + +export interface FileML extends ServerFileMl { + updatedAt: number; +} + +class ServerFileMl { + public fileID: number; + public height?: number; + public width?: number; + public faceEmbedding: ServerFaceEmbeddings; + + public constructor( + fileID: number, + faceEmbedding: ServerFaceEmbeddings, + height?: number, + width?: number, + ) { + this.fileID = fileID; + this.height = height; + this.width = width; + this.faceEmbedding = faceEmbedding; + } +} + +class ServerFaceEmbeddings { + public faces: ServerFace[]; + public version: number; + public client?: string; + public error?: boolean; + + public constructor( + faces: ServerFace[], + version: number, + client?: string, + error?: boolean, + ) { + this.faces = faces; + this.version = version; + this.client = client; + this.error = error; + } +} + +class ServerFace { + public faceID: string; + public embeddings: number[]; + public detection: ServerDetection; + public score: number; + public blur: number; + + public constructor( + faceID: string, + embeddings: number[], + detection: ServerDetection, + score: number, + blur: number, + ) { + this.faceID = faceID; + this.embeddings = embeddings; + this.detection = detection; + this.score = score; + this.blur = blur; + } +} + +class ServerDetection { + public box: ServerFaceBox; + public landmarks: Landmark[]; + + public constructor(box: ServerFaceBox, landmarks: Landmark[]) { + this.box = box; + this.landmarks = landmarks; + } +} + +class ServerFaceBox { + public xMin: number; + public yMin: number; + public width: number; + public 
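getEnteWorker above implements a lazy, fixed-size round-robin pool: task id i always maps to slot i % concurrency, and a worker is only spawned the first time its slot is hit, so at most `concurrency` crypto workers ever exist. The pattern in isolation, as a generic sketch:

```ts
// Distilled form of getEnteWorker's pooling: lazy creation, modulo
// assignment, at most `size` instances however many ids arrive.
class LazyRoundRobinPool<T> {
    private slots: (T | undefined)[];
    constructor(
        private readonly size: number,
        private readonly create: () => T,
    ) {
        this.slots = new Array(size);
    }
    get(id: number): T {
        const slot = id % this.size;
        let instance = this.slots[slot];
        if (!instance) {
            instance = this.create();
            this.slots[slot] = instance;
        }
        return instance;
    }
}
```

A side note on the getConcurrency export just below the class: `haveWindow() && Math.max(...)` evaluates to `false` rather than a number when no window object exists, so callers are expected to invoke it only in a browser context.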
height: number; + + public constructor( + xMin: number, + yMin: number, + width: number, + height: number, + ) { + this.xMin = xMin; + this.yMin = yMin; + this.width = width; + this.height = height; + } +} + +function LocalFileMlDataToServerFileMl( + localFileMlData: MlFileData, +): ServerFileMl { + if ( + localFileMlData.errorCount > 0 && + localFileMlData.lastErrorMessage !== undefined + ) { + return null; + } + const imageDimensions = localFileMlData.imageDimensions; + + const faces: ServerFace[] = []; + for (let i = 0; i < localFileMlData.faces.length; i++) { + const face: Face = localFileMlData.faces[i]; + const faceID = face.id; + const embedding = face.embedding; + const score = face.detection.probability; + const blur = face.blurValue; + const detection: FaceDetection = face.detection; + const box = detection.box; + const landmarks = detection.landmarks; + const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height); + const newLandmarks: Landmark[] = []; + for (let j = 0; j < landmarks.length; j++) { + newLandmarks.push({ + x: landmarks[j].x, + y: landmarks[j].y, + } as Landmark); + } + + const newFaceObject = new ServerFace( + faceID, + Array.from(embedding), + new ServerDetection(newBox, newLandmarks), + score, + blur, + ); + faces.push(newFaceObject); + } + const faceEmbeddings = new ServerFaceEmbeddings( + faces, + 1, + localFileMlData.lastErrorMessage, + ); + return new ServerFileMl( + localFileMlData.fileId, + faceEmbeddings, + imageDimensions.height, + imageDimensions.width, + ); +} + +export function logQueueStats(queue: PQueue, name: string) { + queue.on("active", () => + log.info( + `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`, + ), + ); + queue.on("idle", () => log.info(`queuestats: ${name}: Idle`)); + queue.on("error", (error) => + console.error(`queuestats: ${name}: Error, `, error), + ); +} diff --git a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts index c75d4b14a0..700d358e04 100644 --- a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts +++ b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts @@ -5,20 +5,110 @@ import { eventBus, Events } from "@ente/shared/events"; import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers"; import debounce from "debounce"; import PQueue from "p-queue"; +import { getMLSyncJobConfig } from "services/machineLearning/machineLearningService"; +import mlIDbStorage from "services/ml/db"; import { MLSyncResult } from "services/ml/types"; -import { JobResult } from "types/common/job"; import { EnteFile } from "types/file"; import { getDedicatedMLWorker } from "utils/comlink/ComlinkMLWorker"; -import { SimpleJob } from "utils/common/job"; -import { logQueueStats } from "utils/machineLearning"; -import { getMLSyncJobConfig } from "utils/machineLearning/config"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; import { DedicatedMLWorker } from "worker/ml.worker"; +import { logQueueStats } from "./machineLearningService"; const LIVE_SYNC_IDLE_DEBOUNCE_SEC = 30; const LIVE_SYNC_QUEUE_TIMEOUT_SEC = 300; const LOCAL_FILES_UPDATED_DEBOUNCE_SEC = 30; +export type JobState = "Scheduled" | "Running" | "NotScheduled"; + +export interface JobConfig { + intervalSec: number; + maxItervalSec: number; + backoffMultiplier: number; +} + +export interface JobResult { + shouldBackoff: boolean; +} + +export class SimpleJob { + private config: JobConfig; + private runCallback: () => Promise; + 
private state: JobState; + private stopped: boolean; + private intervalSec: number; + private nextTimeoutId: ReturnType; + + constructor(config: JobConfig, runCallback: () => Promise) { + this.config = config; + this.runCallback = runCallback; + this.state = "NotScheduled"; + this.stopped = true; + this.intervalSec = this.config.intervalSec; + } + + public resetInterval() { + this.intervalSec = this.config.intervalSec; + } + + public start() { + this.stopped = false; + this.resetInterval(); + if (this.state !== "Running") { + this.scheduleNext(); + } else { + log.info("Job already running, not scheduling"); + } + } + + private scheduleNext() { + if (this.state === "Scheduled" || this.nextTimeoutId) { + this.clearScheduled(); + } + + this.nextTimeoutId = setTimeout( + () => this.run(), + this.intervalSec * 1000, + ); + this.state = "Scheduled"; + log.info("Scheduled next job after: ", this.intervalSec); + } + + async run() { + this.nextTimeoutId = undefined; + this.state = "Running"; + + try { + const jobResult = await this.runCallback(); + if (jobResult && jobResult.shouldBackoff) { + this.intervalSec = Math.min( + this.config.maxItervalSec, + this.intervalSec * this.config.backoffMultiplier, + ); + } else { + this.resetInterval(); + } + log.info("Job completed"); + } catch (e) { + console.error("Error while running Job: ", e); + } finally { + this.state = "NotScheduled"; + !this.stopped && this.scheduleNext(); + } + } + + // currently client is responsible to terminate running job + public stop() { + this.stopped = true; + this.clearScheduled(); + } + + private clearScheduled() { + clearTimeout(this.nextTimeoutId); + this.nextTimeoutId = undefined; + this.state = "NotScheduled"; + log.info("Cleared next job"); + } +} + export interface MLSyncJobResult extends JobResult { mlSyncResult: MLSyncResult; } diff --git a/web/apps/photos/src/services/machineLearning/peopleService.ts b/web/apps/photos/src/services/machineLearning/peopleService.ts index f4ed5104bb..f7d5cf38a1 100644 --- a/web/apps/photos/src/services/machineLearning/peopleService.ts +++ b/web/apps/photos/src/services/machineLearning/peopleService.ts @@ -1,14 +1,8 @@ import log from "@/next/log"; +import mlIDbStorage from "services/ml/db"; import { Face, MLSyncContext, Person } from "services/ml/types"; -import { - findFirstIfSorted, - getAllFacesFromMap, - getLocalFile, - getOriginalImageBitmap, - isDifferentOrOld, -} from "utils/machineLearning"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; -import FaceService from "./faceService"; +import FaceService, { isDifferentOrOld } from "./faceService"; +import { getLocalFile, getOriginalImageBitmap } from "./readerService"; class PeopleService { async syncPeopleIndex(syncContext: MLSyncContext) { @@ -92,3 +86,28 @@ class PeopleService { } export default new PeopleService(); + +function findFirstIfSorted( + elements: Array, + comparator: (a: T, b: T) => number, +) { + if (!elements || elements.length < 1) { + return; + } + let first = elements[0]; + + for (let i = 1; i < elements.length; i++) { + const comp = comparator(elements[i], first); + if (comp < 0) { + first = elements[i]; + } + } + + return first; +} + +function getAllFacesFromMap(allFacesMap: Map>) { + const allFaces = [...allFacesMap.values()].flat(); + + return allFaces; +} diff --git a/web/apps/photos/src/services/machineLearning/readerService.ts b/web/apps/photos/src/services/machineLearning/readerService.ts index 44d0738f51..6ad4c80e81 100644 --- a/web/apps/photos/src/services/machineLearning/readerService.ts +++ 
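SimpleJob, moved here from utils/common/job.ts, reschedules itself after every run, multiplying the interval by backoffMultiplier (capped at maxItervalSec; the field name carries a pre-existing typo, "Iterval") whenever the callback reports shouldBackoff, and resetting it otherwise. A usage sketch with the DEFAULT_ML_SYNC_JOB_CONFIG from this diff; doSync is a hypothetical work function:

```ts
// With DEFAULT_ML_SYNC_JOB_CONFIG the interval starts at 5s and doubles on
// each backoff (5 → 10 → 20 → …), saturating at 960s.
const job = new SimpleJob(DEFAULT_ML_SYNC_JOB_CONFIG, async () => {
    const hadNewWork = await doSync(); // hypothetical work function
    return { shouldBackoff: !hadNewWork };
});
job.start();
// ... later, e.g. on logout:
job.stop();
```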
b/web/apps/photos/src/services/machineLearning/readerService.ts @@ -1,11 +1,18 @@ import { FILE_TYPE } from "@/media/file-type"; +import { decodeLivePhoto } from "@/media/live-photo"; import log from "@/next/log"; -import { MLSyncContext, MLSyncFileContext } from "services/ml/types"; +import PQueue from "p-queue"; +import DownloadManager from "services/download"; +import { getLocalFiles } from "services/fileService"; +import { Dimensions } from "services/ml/geom"; import { - getLocalFileImageBitmap, - getOriginalImageBitmap, - getThumbnailImageBitmap, -} from "utils/machineLearning"; + DetectedFace, + MLSyncContext, + MLSyncFileContext, +} from "services/ml/types"; +import { EnteFile } from "types/file"; +import { getRenderableImage } from "utils/file"; +import { clamp } from "utils/image"; class ReaderService { async getImageBitmap( @@ -55,3 +62,95 @@ class ReaderService { } } export default new ReaderService(); + +export async function getLocalFile(fileId: number) { + const localFiles = await getLocalFiles(); + return localFiles.find((f) => f.id === fileId); +} + +export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) { + const xMin = clamp( + detectedFace.detection.box.x / imageDims.width, + 0.0, + 0.999999, + ) + .toFixed(5) + .substring(2); + const yMin = clamp( + detectedFace.detection.box.y / imageDims.height, + 0.0, + 0.999999, + ) + .toFixed(5) + .substring(2); + const xMax = clamp( + (detectedFace.detection.box.x + detectedFace.detection.box.width) / + imageDims.width, + 0.0, + 0.999999, + ) + .toFixed(5) + .substring(2); + const yMax = clamp( + (detectedFace.detection.box.y + detectedFace.detection.box.height) / + imageDims.height, + 0.0, + 0.999999, + ) + .toFixed(5) + .substring(2); + + const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`; + const faceID = `${detectedFace.fileId}_${rawFaceID}`; + + return faceID; +} + +async function getImageBlobBitmap(blob: Blob): Promise { + return await createImageBitmap(blob); +} + +async function getOriginalFile(file: EnteFile, queue?: PQueue) { + let fileStream; + if (queue) { + fileStream = await queue.add(() => DownloadManager.getFile(file)); + } else { + fileStream = await DownloadManager.getFile(file); + } + return new Response(fileStream).blob(); +} + +async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) { + const fileBlob = await getOriginalFile(file, queue); + if (file.metadata.fileType === FILE_TYPE.IMAGE) { + return await getRenderableImage(file.metadata.title, fileBlob); + } else { + const { imageFileName, imageData } = await decodeLivePhoto( + file.metadata.title, + fileBlob, + ); + return await getRenderableImage(imageFileName, new Blob([imageData])); + } +} + +export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) { + const fileBlob = await getOriginalConvertedFile(file, queue); + log.info("[MLService] Got file: ", file.id.toString()); + return getImageBlobBitmap(fileBlob); +} + +export async function getThumbnailImageBitmap(file: EnteFile) { + const thumb = await DownloadManager.getThumbnail(file); + log.info("[MLService] Got thumbnail: ", file.id.toString()); + + return getImageBlobBitmap(new Blob([thumb])); +} + +export async function getLocalFileImageBitmap( + enteFile: EnteFile, + localFile: globalThis.File, +) { + let fileBlob = localFile as Blob; + fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob); + return getImageBlobBitmap(fileBlob); +} diff --git a/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts 
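getFaceId above derives a stable id from the face box's relative coordinates: each edge is clamped to [0, 0.999999], rendered with five decimals, and stripped of the leading "0.", producing a fixed five-digit token per edge. A worked example of that encoding:

```ts
// The per-edge encoding from getFaceId: clamp, format, drop the "0.".
const clamp = (x: number, lo: number, hi: number) =>
    Math.min(hi, Math.max(lo, x));
const edgeToken = (v: number, size: number) =>
    clamp(v / size, 0.0, 0.999999).toFixed(5).substring(2);
console.log(edgeToken(128, 1000)); // "12800"
console.log(edgeToken(7, 1000)); // "00700" (leading zeros preserved)
// A full id is then `${fileId}_${xMin}_${yMin}_${xMax}_${yMax}`.
```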
b/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts index 7e8af1a269..8856576bee 100644 --- a/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts +++ b/web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts @@ -323,14 +323,10 @@ function transformBox(box: Box, transform: Matrix) { const topLeft = transformPoint(box.topLeft, transform); const bottomRight = transformPoint(box.bottomRight, transform); - return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y); -} - -function newBoxFromPoints( - left: number, - top: number, - right: number, - bottom: number, -) { - return boxFromBoundingBox({ left, top, right, bottom }); + return boxFromBoundingBox({ + left: topLeft.x, + top: topLeft.y, + right: bottomRight.x, + bottom: bottomRight.y, + }); } diff --git a/web/apps/photos/src/utils/storage/mlIDbStorage.ts b/web/apps/photos/src/services/ml/db.ts similarity index 99% rename from web/apps/photos/src/utils/storage/mlIDbStorage.ts rename to web/apps/photos/src/services/ml/db.ts index c90832df3d..90b2f4aa07 100644 --- a/web/apps/photos/src/utils/storage/mlIDbStorage.ts +++ b/web/apps/photos/src/services/ml/db.ts @@ -1,11 +1,5 @@ import { haveWindow } from "@/next/env"; import log from "@/next/log"; -import { - DEFAULT_ML_SEARCH_CONFIG, - DEFAULT_ML_SYNC_CONFIG, - DEFAULT_ML_SYNC_JOB_CONFIG, - MAX_ML_SYNC_ERROR_COUNT, -} from "constants/mlConfig"; import { DBSchema, IDBPDatabase, @@ -15,6 +9,12 @@ import { openDB, } from "idb"; import isElectron from "is-electron"; +import { + DEFAULT_ML_SEARCH_CONFIG, + DEFAULT_ML_SYNC_CONFIG, + DEFAULT_ML_SYNC_JOB_CONFIG, + MAX_ML_SYNC_ERROR_COUNT, +} from "services/machineLearning/machineLearningService"; import { Face, MLLibraryData, MlFileData, Person } from "services/ml/types"; export interface IndexStatus { diff --git a/web/apps/photos/src/services/ml/geom.ts b/web/apps/photos/src/services/ml/geom.ts index 1cf51963cf..556e2b309d 100644 --- a/web/apps/photos/src/services/ml/geom.ts +++ b/web/apps/photos/src/services/ml/geom.ts @@ -76,3 +76,17 @@ export class Box implements IRect { return new Box({ x, y, width, height }); } } + +export function enlargeBox(box: Box, factor: number = 1.5) { + const center = new Point(box.x + box.width / 2, box.y + box.height / 2); + + const size = new Point(box.width, box.height); + const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2); + + return boxFromBoundingBox({ + left: center.x - newHalfSize.x, + top: center.y - newHalfSize.y, + right: center.x + newHalfSize.x, + bottom: center.y + newHalfSize.y, + }); +} diff --git a/web/apps/photos/src/services/ml/types.ts b/web/apps/photos/src/services/ml/types.ts index 2ee42b36a3..422cf9d4aa 100644 --- a/web/apps/photos/src/services/ml/types.ts +++ b/web/apps/photos/src/services/ml/types.ts @@ -329,8 +329,3 @@ export interface MachineLearningWorker { close(): void; } - -export interface ClipEmbedding { - embedding: Float32Array; - model: "ggml-clip" | "onnx-clip"; -} diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts index 7c00dec724..a212fc9dcf 100644 --- a/web/apps/photos/src/services/searchService.ts +++ b/web/apps/photos/src/services/searchService.ts @@ -2,6 +2,8 @@ import { FILE_TYPE } from "@/media/file-type"; import log from "@/next/log"; import * as chrono from "chrono-node"; import { t } from "i18next"; +import { getMLSyncConfig } from "services/machineLearning/machineLearningService"; +import mlIDbStorage from 
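The simplified transformBox above maps a box by transforming only its topLeft and bottomRight corners. That is exact for the scale-and-translate transforms this file uses to move between image and model-input coordinates, but a transform with rotation would need the axis-aligned bounds of all four mapped corners. A sketch of the general case, with `apply` as a hypothetical point transform:

```ts
// General form: under transforms that may rotate, take the axis-aligned
// bounds of every transformed corner rather than just two of them.
interface Pt {
    x: number;
    y: number;
}
function transformedBounds(corners: Pt[], apply: (p: Pt) => Pt) {
    const pts = corners.map(apply);
    return {
        left: Math.min(...pts.map((p) => p.x)),
        top: Math.min(...pts.map((p) => p.y)),
        right: Math.max(...pts.map((p) => p.x)),
        bottom: Math.max(...pts.map((p) => p.y)),
    };
}
```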
"services/ml/db"; import { Person } from "services/ml/types"; import { Collection } from "types/collection"; import { EntityType, LocationTag, LocationTagData } from "types/entity"; @@ -16,12 +18,9 @@ import { } from "types/search"; import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker"; import { getUniqueFiles } from "utils/file"; -import { getAllPeople } from "utils/machineLearning"; -import { getMLSyncConfig } from "utils/machineLearning/config"; import { getFormattedDate } from "utils/search"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; import { clipService, computeClipMatchScore } from "./clip-service"; -import { getLocalEmbeddings } from "./embeddingService"; +import { localCLIPEmbeddings } from "./embeddingService"; import { getLatestEntities } from "./entityService"; import locationSearchService, { City } from "./locationSearchService"; @@ -376,7 +375,7 @@ const searchClip = async ( await clipService.getTextEmbeddingIfAvailable(searchPhrase); if (!textEmbedding) return undefined; - const imageEmbeddings = await getLocalEmbeddings(); + const imageEmbeddings = await localCLIPEmbeddings(); const clipSearchResult = new Map( ( await Promise.all( @@ -430,3 +429,14 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search { return { clip: option.value as ClipSearchScores }; } } + +async function getAllPeople(limit: number = undefined) { + let people: Array = await mlIDbStorage.getAllPeople(); + // await mlPeopleStore.iterate((person) => { + // people.push(person); + // }); + people = people ?? []; + return people + .sort((p1, p2) => p2.files.length - p1.files.length) + .slice(0, limit); +} diff --git a/web/apps/photos/src/types/common/job.ts b/web/apps/photos/src/types/common/job.ts deleted file mode 100644 index fe42e4aaf2..0000000000 --- a/web/apps/photos/src/types/common/job.ts +++ /dev/null @@ -1,11 +0,0 @@ -export type JobState = "Scheduled" | "Running" | "NotScheduled"; - -export interface JobConfig { - intervalSec: number; - maxItervalSec: number; - backoffMultiplier: number; -} - -export interface JobResult { - shouldBackoff: boolean; -} diff --git a/web/apps/photos/src/types/embedding.tsx b/web/apps/photos/src/types/embedding.tsx index d4719986bc..161244c159 100644 --- a/web/apps/photos/src/types/embedding.tsx +++ b/web/apps/photos/src/types/embedding.tsx @@ -1,9 +1,9 @@ /** - * The embeddings models that we support. + * The embeddings that we (the current client) knows how to handle. * * This is an exhaustive set of values we pass when PUT-ting encrypted * embeddings on the server. However, we should be prepared to receive an - * {@link EncryptedEmbedding} with a model value distinct from one of these. + * {@link EncryptedEmbedding} with a model value different from these. 
diff --git a/web/apps/photos/src/types/embedding.tsx b/web/apps/photos/src/types/embedding.tsx
index d4719986bc..161244c159 100644
--- a/web/apps/photos/src/types/embedding.tsx
+++ b/web/apps/photos/src/types/embedding.tsx
@@ -1,9 +1,9 @@
 /**
- * The embeddings models that we support.
+ * The embeddings that we (the current client) know how to handle.
  *
  * This is an exhaustive set of values we pass when PUT-ting encrypted
  * embeddings on the server. However, we should be prepared to receive an
- * {@link EncryptedEmbedding} with a model value distinct from one of these.
+ * {@link EncryptedEmbedding} with a model value different from these.
  */
 export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
diff --git a/web/apps/photos/src/types/machineLearning/ui.ts b/web/apps/photos/src/types/machineLearning/ui.ts
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/web/apps/photos/src/types/search/index.ts b/web/apps/photos/src/types/search/index.ts
index 7cff13298f..e08b842a30 100644
--- a/web/apps/photos/src/types/search/index.ts
+++ b/web/apps/photos/src/types/search/index.ts
@@ -1,9 +1,9 @@
 import { FILE_TYPE } from "@/media/file-type";
 import { City } from "services/locationSearchService";
+import { IndexStatus } from "services/ml/db";
 import { Person } from "services/ml/types";
 import { LocationTagData } from "types/entity";
 import { EnteFile } from "types/file";
-import { IndexStatus } from "utils/storage/mlIDbStorage";
 
 export enum SuggestionType {
     DATE = "DATE",
diff --git a/web/apps/photos/src/utils/common/job.ts b/web/apps/photos/src/utils/common/job.ts
deleted file mode 100644
index 365f879e95..0000000000
--- a/web/apps/photos/src/utils/common/job.ts
+++ /dev/null
@@ -1,82 +0,0 @@
-import log from "@/next/log";
-import { JobConfig, JobResult, JobState } from "types/common/job";
-
-export class SimpleJob<R extends JobResult> {
-    private config: JobConfig;
-    private runCallback: () => Promise<R>;
-    private state: JobState;
-    private stopped: boolean;
-    private intervalSec: number;
-    private nextTimeoutId: ReturnType<typeof setTimeout>;
-
-    constructor(config: JobConfig, runCallback: () => Promise<R>) {
-        this.config = config;
-        this.runCallback = runCallback;
-        this.state = "NotScheduled";
-        this.stopped = true;
-        this.intervalSec = this.config.intervalSec;
-    }
-
-    public resetInterval() {
-        this.intervalSec = this.config.intervalSec;
-    }
-
-    public start() {
-        this.stopped = false;
-        this.resetInterval();
-        if (this.state !== "Running") {
-            this.scheduleNext();
-        } else {
-            log.info("Job already running, not scheduling");
-        }
-    }
-
-    private scheduleNext() {
-        if (this.state === "Scheduled" || this.nextTimeoutId) {
-            this.clearScheduled();
-        }
-
-        this.nextTimeoutId = setTimeout(
-            () => this.run(),
-            this.intervalSec * 1000,
-        );
-        this.state = "Scheduled";
-        log.info("Scheduled next job after: ", this.intervalSec);
-    }
-
-    async run() {
-        this.nextTimeoutId = undefined;
-        this.state = "Running";
-
-        try {
-            const jobResult = await this.runCallback();
-            if (jobResult && jobResult.shouldBackoff) {
-                this.intervalSec = Math.min(
-                    this.config.maxItervalSec,
-                    this.intervalSec * this.config.backoffMultiplier,
-                );
-            } else {
-                this.resetInterval();
-            }
-            log.info("Job completed");
-        } catch (e) {
-            console.error("Error while running Job: ", e);
-        } finally {
-            this.state = "NotScheduled";
-            !this.stopped && this.scheduleNext();
-        }
-    }
-
-    // currently client is responsible to terminate running job
-    public stop() {
-        this.stopped = true;
-        this.clearScheduled();
-    }
-
-    private clearScheduled() {
-        clearTimeout(this.nextTimeoutId);
-        this.nextTimeoutId = undefined;
-        this.state = "NotScheduled";
-        log.info("Cleared next job");
-    }
-}
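Reviewer note: the deleted SimpleJob doubles its interval whenever the callback reports shouldBackoff and caps it at maxItervalSec (the field's actual, typo'd, spelling). With the default ML job config (intervalSec 5, backoffMultiplier 2, cap 960) the retry schedule is 5 → 10 → 20 → 40 → 80 → 160 → 320 → 640 → 960 seconds, then it stays at 960. A standalone check of that arithmetic:

// Backoff progression implied by SimpleJob.run, using the default config.
const config = { intervalSec: 5, maxItervalSec: 960, backoffMultiplier: 2 };
let interval = config.intervalSec;
const schedule = [interval];
while (interval < config.maxItervalSec) {
    interval = Math.min(config.maxItervalSec, interval * config.backoffMultiplier);
    schedule.push(interval);
}
console.log(schedule); // [5, 10, 20, 40, 80, 160, 320, 640, 960]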
diff --git a/web/apps/photos/src/utils/embedding.ts b/web/apps/photos/src/utils/embedding.ts
deleted file mode 100644
index 00012f174f..0000000000
--- a/web/apps/photos/src/utils/embedding.ts
+++ /dev/null
@@ -1,36 +0,0 @@
-import { Embedding } from "types/embedding";
-import { FileML } from "./machineLearning/mldataMappers";
-
-export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
-    const latestVersionEntities = new Map<number, Embedding>();
-    embeddings.forEach((embedding) => {
-        if (!embedding?.fileID) {
-            return;
-        }
-        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
-        if (
-            !existingEmbeddings ||
-            existingEmbeddings.updatedAt < embedding.updatedAt
-        ) {
-            latestVersionEntities.set(embedding.fileID, embedding);
-        }
-    });
-    return Array.from(latestVersionEntities.values());
-};
-
-export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
-    const latestVersionEntities = new Map<number, FileML>();
-    embeddings.forEach((embedding) => {
-        if (!embedding?.fileID) {
-            return;
-        }
-        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
-        if (
-            !existingEmbeddings ||
-            existingEmbeddings.updatedAt < embedding.updatedAt
-        ) {
-            latestVersionEntities.set(embedding.fileID, embedding);
-        }
-    });
-    return Array.from(latestVersionEntities.values());
-};
diff --git a/web/apps/photos/src/utils/image/index.ts b/web/apps/photos/src/utils/image/index.ts
index 39263ab308..bdaf64d735 100644
--- a/web/apps/photos/src/utils/image/index.ts
+++ b/web/apps/photos/src/utils/image/index.ts
@@ -1,9 +1,8 @@
 // these utils only work in env where OffscreenCanvas is available
 
 import { Matrix, inverse } from "ml-matrix";
-import { Box, Dimensions } from "services/ml/geom";
+import { Box, Dimensions, enlargeBox } from "services/ml/geom";
 import { FaceAlignment } from "services/ml/types";
-import { enlargeBox } from "utils/machineLearning";
 
 export function normalizePixelBetween0And1(pixelValue: number) {
     return pixelValue / 255.0;
 }
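Reviewer note: the two deleted helpers above are the same reduction — keep the entry with the newest updatedAt per fileID — instantiated at two element types. A generic version (hypothetical; not what this diff adopts) that makes the duplication visible:

// Keep only the latest-updated entity per fileID; works for both
// Embedding and FileML since each carries fileID and updatedAt.
const latestByFileID = <T extends { fileID: number; updatedAt: number }>(
    entities: T[],
): T[] => {
    const latest = new Map<number, T>();
    for (const e of entities) {
        if (!e?.fileID) continue; // same guard as the originals
        const existing = latest.get(e.fileID);
        if (!existing || existing.updatedAt < e.updatedAt) latest.set(e.fileID, e);
    }
    return [...latest.values()];
};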
diff --git a/web/apps/photos/src/utils/machineLearning/config.ts b/web/apps/photos/src/utils/machineLearning/config.ts
deleted file mode 100644
index f47a96b055..0000000000
--- a/web/apps/photos/src/utils/machineLearning/config.ts
+++ /dev/null
@@ -1,48 +0,0 @@
-import {
-    DEFAULT_ML_SEARCH_CONFIG,
-    DEFAULT_ML_SYNC_CONFIG,
-    DEFAULT_ML_SYNC_JOB_CONFIG,
-} from "constants/mlConfig";
-import { MLSearchConfig, MLSyncConfig } from "services/ml/types";
-import { JobConfig } from "types/common/job";
-import mlIDbStorage, {
-    ML_SEARCH_CONFIG_NAME,
-    ML_SYNC_CONFIG_NAME,
-    ML_SYNC_JOB_CONFIG_NAME,
-} from "utils/storage/mlIDbStorage";
-import { isInternalUserForML } from "utils/user";
-
-export async function getMLSyncJobConfig() {
-    return mlIDbStorage.getConfig(
-        ML_SYNC_JOB_CONFIG_NAME,
-        DEFAULT_ML_SYNC_JOB_CONFIG,
-    );
-}
-
-export async function getMLSyncConfig() {
-    return mlIDbStorage.getConfig(ML_SYNC_CONFIG_NAME, DEFAULT_ML_SYNC_CONFIG);
-}
-
-export async function getMLSearchConfig() {
-    if (isInternalUserForML()) {
-        return mlIDbStorage.getConfig(
-            ML_SEARCH_CONFIG_NAME,
-            DEFAULT_ML_SEARCH_CONFIG,
-        );
-    }
-    // Force disabled for everyone else while we finalize it to avoid redundant
-    // reindexing for users.
-    return DEFAULT_ML_SEARCH_CONFIG;
-}
-
-export async function updateMLSyncJobConfig(newConfig: JobConfig) {
-    return mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig);
-}
-
-export async function updateMLSyncConfig(newConfig: MLSyncConfig) {
-    return mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig);
-}
-
-export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
-    return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
-}
diff --git a/web/apps/photos/src/utils/machineLearning/faceAlign.ts b/web/apps/photos/src/utils/machineLearning/faceAlign.ts
deleted file mode 100644
index 3e6846ff80..0000000000
--- a/web/apps/photos/src/utils/machineLearning/faceAlign.ts
+++ /dev/null
@@ -1,87 +0,0 @@
-import { Matrix } from "ml-matrix";
-import { Point } from "services/ml/geom";
-import { FaceAlignment, FaceDetection } from "services/ml/types";
-import { getSimilarityTransformation } from "similarity-transformation";
-
-const ARCFACE_LANDMARKS = [
-    [38.2946, 51.6963],
-    [73.5318, 51.5014],
-    [56.0252, 71.7366],
-    [56.1396, 92.2848],
-] as Array<[number, number]>;
-
-const ARCFACE_LANDMARKS_FACE_SIZE = 112;
-
-const ARC_FACE_5_LANDMARKS = [
-    [38.2946, 51.6963],
-    [73.5318, 51.5014],
-    [56.0252, 71.7366],
-    [41.5493, 92.3655],
-    [70.7299, 92.2041],
-] as Array<[number, number]>;
-
-export function getArcfaceAlignment(
-    faceDetection: FaceDetection,
-): FaceAlignment {
-    const landmarkCount = faceDetection.landmarks.length;
-    return getFaceAlignmentUsingSimilarityTransform(
-        faceDetection,
-        normalizeLandmarks(
-            landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
-            ARCFACE_LANDMARKS_FACE_SIZE,
-        ),
-    );
-}
-
-function getFaceAlignmentUsingSimilarityTransform(
-    faceDetection: FaceDetection,
-    alignedLandmarks: Array<[number, number]>,
-    // alignmentMethod: Versioned<FaceAlignmentMethod>
-): FaceAlignment {
-    const landmarksMat = new Matrix(
-        faceDetection.landmarks
-            .map((p) => [p.x, p.y])
-            .slice(0, alignedLandmarks.length),
-    ).transpose();
-    const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();
-
-    const simTransform = getSimilarityTransformation(
-        landmarksMat,
-        alignedLandmarksMat,
-    );
-
-    const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
-    const TR = simTransform.translation;
-
-    const affineMatrix = [
-        [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
-        [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
-        [0, 0, 1],
-    ];
-
-    const size = 1 / simTransform.scale;
-    const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
-    const centerMat = simTransform.fromMean.sub(meanTranslation);
-    const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
-    const rotation = -Math.atan2(
-        simTransform.rotation.get(0, 1),
-        simTransform.rotation.get(0, 0),
-    );
-    // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
-
-    return {
-        affineMatrix,
-        center,
-        size,
-        rotation,
-    };
-}
-
-function normalizeLandmarks(
-    landmarks: Array<[number, number]>,
-    faceSize: number,
-): Array<[number, number]> {
-    return landmarks.map((landmark) =>
-        landmark.map((p) => p / faceSize),
-    ) as Array<[number, number]>;
-}
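Reviewer note: in the deleted faceAlign.ts, normalizeLandmarks rescales the canonical ArcFace landmarks from the 112 px reference face into [0, 1] coordinates before the similarity transform is fit. Checking one landmark: [38.2946, 51.6963] / 112 ≈ [0.34192, 0.46157]. As a standalone snippet:

// Same arithmetic as the deleted normalizeLandmarks, written point-wise.
const normalize = (landmarks: [number, number][], faceSize: number) =>
    landmarks.map(([x, y]) => [x / faceSize, y / faceSize] as [number, number]);

console.log(normalize([[38.2946, 51.6963]], 112));
// ≈ [[0.34192, 0.46157]]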
} from "services/ml/types"; -import { cropWithRotation } from "utils/image"; -import { enlargeBox } from "."; - -export function getFaceCrop( - imageBitmap: ImageBitmap, - alignment: FaceAlignment, - config: FaceCropConfig, -): FaceCrop { - const alignmentBox = new Box({ - x: alignment.center.x - alignment.size / 2, - y: alignment.center.y - alignment.size / 2, - width: alignment.size, - height: alignment.size, - }).round(); - const scaleForPadding = 1 + config.padding * 2; - const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); - const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, { - width: config.maxSize, - height: config.maxSize, - }); - - return { - image: faceImageBitmap, - imageBox: paddedBox, - }; -} diff --git a/web/apps/photos/src/utils/machineLearning/index.ts b/web/apps/photos/src/utils/machineLearning/index.ts deleted file mode 100644 index 987eabc00e..0000000000 --- a/web/apps/photos/src/utils/machineLearning/index.ts +++ /dev/null @@ -1,272 +0,0 @@ -import { FILE_TYPE } from "@/media/file-type"; -import { decodeLivePhoto } from "@/media/live-photo"; -import log from "@/next/log"; -import PQueue from "p-queue"; -import DownloadManager from "services/download"; -import { getLocalFiles } from "services/fileService"; -import { Box, Dimensions, Point, boxFromBoundingBox } from "services/ml/geom"; -import { - DetectedFace, - Face, - FaceAlignment, - MlFileData, - Person, - Versioned, -} from "services/ml/types"; -import { EnteFile } from "types/file"; -import { getRenderableImage } from "utils/file"; -import { clamp, warpAffineFloat32List } from "utils/image"; -import mlIDbStorage from "utils/storage/mlIDbStorage"; - -export function enlargeBox(box: Box, factor: number = 1.5) { - const center = new Point(box.x + box.width / 2, box.y + box.height / 2); - - const size = new Point(box.width, box.height); - const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2); - - return boxFromBoundingBox({ - left: center.x - newHalfSize.x, - top: center.y - newHalfSize.y, - right: center.x + newHalfSize.x, - bottom: center.y + newHalfSize.y, - }); -} - -export function getAllFacesFromMap(allFacesMap: Map>) { - const allFaces = [...allFacesMap.values()].flat(); - - return allFaces; -} - -export async function getLocalFile(fileId: number) { - const localFiles = await getLocalFiles(); - return localFiles.find((f) => f.id === fileId); -} - -export async function extractFaceImagesToFloat32( - faceAlignments: Array, - faceSize: number, - image: ImageBitmap, -): Promise { - const faceData = new Float32Array( - faceAlignments.length * faceSize * faceSize * 3, - ); - for (let i = 0; i < faceAlignments.length; i++) { - const alignedFace = faceAlignments[i]; - const faceDataOffset = i * faceSize * faceSize * 3; - warpAffineFloat32List( - image, - alignedFace, - faceSize, - faceData, - faceDataOffset, - ); - } - return faceData; -} - -export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) { - const xMin = clamp( - detectedFace.detection.box.x / imageDims.width, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const yMin = clamp( - detectedFace.detection.box.y / imageDims.height, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const xMax = clamp( - (detectedFace.detection.box.x + detectedFace.detection.box.width) / - imageDims.width, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const yMax = clamp( - (detectedFace.detection.box.y + detectedFace.detection.box.height) / - imageDims.height, - 0.0, - 0.999999, - ) - 
diff --git a/web/apps/photos/src/utils/machineLearning/index.ts b/web/apps/photos/src/utils/machineLearning/index.ts
deleted file mode 100644
index 987eabc00e..0000000000
--- a/web/apps/photos/src/utils/machineLearning/index.ts
+++ /dev/null
@@ -1,272 +0,0 @@
-import { FILE_TYPE } from "@/media/file-type";
-import { decodeLivePhoto } from "@/media/live-photo";
-import log from "@/next/log";
-import PQueue from "p-queue";
-import DownloadManager from "services/download";
-import { getLocalFiles } from "services/fileService";
-import { Box, Dimensions, Point, boxFromBoundingBox } from "services/ml/geom";
-import {
-    DetectedFace,
-    Face,
-    FaceAlignment,
-    MlFileData,
-    Person,
-    Versioned,
-} from "services/ml/types";
-import { EnteFile } from "types/file";
-import { getRenderableImage } from "utils/file";
-import { clamp, warpAffineFloat32List } from "utils/image";
-import mlIDbStorage from "utils/storage/mlIDbStorage";
-
-export function enlargeBox(box: Box, factor: number = 1.5) {
-    const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
-
-    const size = new Point(box.width, box.height);
-    const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2);
-
-    return boxFromBoundingBox({
-        left: center.x - newHalfSize.x,
-        top: center.y - newHalfSize.y,
-        right: center.x + newHalfSize.x,
-        bottom: center.y + newHalfSize.y,
-    });
-}
-
-export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
-    const allFaces = [...allFacesMap.values()].flat();
-
-    return allFaces;
-}
-
-export async function getLocalFile(fileId: number) {
-    const localFiles = await getLocalFiles();
-    return localFiles.find((f) => f.id === fileId);
-}
-
-export async function extractFaceImagesToFloat32(
-    faceAlignments: Array<FaceAlignment>,
-    faceSize: number,
-    image: ImageBitmap,
-): Promise<Float32Array> {
-    const faceData = new Float32Array(
-        faceAlignments.length * faceSize * faceSize * 3,
-    );
-    for (let i = 0; i < faceAlignments.length; i++) {
-        const alignedFace = faceAlignments[i];
-        const faceDataOffset = i * faceSize * faceSize * 3;
-        warpAffineFloat32List(
-            image,
-            alignedFace,
-            faceSize,
-            faceData,
-            faceDataOffset,
-        );
-    }
-    return faceData;
-}
-
-export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
-    const xMin = clamp(
-        detectedFace.detection.box.x / imageDims.width,
-        0.0,
-        0.999999,
-    )
-        .toFixed(5)
-        .substring(2);
-    const yMin = clamp(
-        detectedFace.detection.box.y / imageDims.height,
-        0.0,
-        0.999999,
-    )
-        .toFixed(5)
-        .substring(2);
-    const xMax = clamp(
-        (detectedFace.detection.box.x + detectedFace.detection.box.width) /
-            imageDims.width,
-        0.0,
-        0.999999,
-    )
-        .toFixed(5)
-        .substring(2);
-    const yMax = clamp(
-        (detectedFace.detection.box.y + detectedFace.detection.box.height) /
-            imageDims.height,
-        0.0,
-        0.999999,
-    )
-        .toFixed(5)
-        .substring(2);
-
-    const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
-    const faceID = `${detectedFace.fileId}_${rawFaceID}`;
-
-    return faceID;
-}
-
-export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
-    return await createImageBitmap(blob);
-}
-
-async function getOriginalFile(file: EnteFile, queue?: PQueue) {
-    let fileStream;
-    if (queue) {
-        fileStream = await queue.add(() => DownloadManager.getFile(file));
-    } else {
-        fileStream = await DownloadManager.getFile(file);
-    }
-    return new Response(fileStream).blob();
-}
-
-async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
-    const fileBlob = await getOriginalFile(file, queue);
-    if (file.metadata.fileType === FILE_TYPE.IMAGE) {
-        return await getRenderableImage(file.metadata.title, fileBlob);
-    } else {
-        const { imageFileName, imageData } = await decodeLivePhoto(
-            file.metadata.title,
-            fileBlob,
-        );
-        return await getRenderableImage(imageFileName, new Blob([imageData]));
-    }
-}
-
-export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) {
-    const fileBlob = await getOriginalConvertedFile(file, queue);
-    log.info("[MLService] Got file: ", file.id.toString());
-    return getImageBlobBitmap(fileBlob);
-}
-
-export async function getThumbnailImageBitmap(file: EnteFile) {
-    const thumb = await DownloadManager.getThumbnail(file);
-    log.info("[MLService] Got thumbnail: ", file.id.toString());
-
-    return getImageBlobBitmap(new Blob([thumb]));
-}
-
-export async function getLocalFileImageBitmap(
-    enteFile: EnteFile,
-    localFile: globalThis.File,
-) {
-    let fileBlob = localFile as Blob;
-    fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
-    return getImageBlobBitmap(fileBlob);
-}
-
-export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
-    let startTime = Date.now();
-    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
-    log.info(
-        "getPeopleList:mlFilesStore:getItem",
-        Date.now() - startTime,
-        "ms",
-    );
-    if (!mlFileData?.faces || mlFileData.faces.length < 1) {
-        return [];
-    }
-
-    const peopleIds = mlFileData.faces
-        .filter((f) => f.personId !== null && f.personId !== undefined)
-        .map((f) => f.personId);
-    if (!peopleIds || peopleIds.length < 1) {
-        return [];
-    }
-    // log.info("peopleIds: ", peopleIds);
-    startTime = Date.now();
-    const peoplePromises = peopleIds.map(
-        (p) => mlIDbStorage.getPerson(p) as Promise<Person>,
-    );
-    const peopleList = await Promise.all(peoplePromises);
-    log.info(
-        "getPeopleList:mlPeopleStore:getItems",
-        Date.now() - startTime,
-        "ms",
-    );
-    // log.info("peopleList: ", peopleList);
-
-    return peopleList;
-}
-
-export async function getUnidentifiedFaces(
-    file: EnteFile,
-): Promise<Array<Face>> {
-    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
-
-    return mlFileData?.faces?.filter(
-        (f) => f.personId === null || f.personId === undefined,
-    );
-}
-
-export async function getAllPeople(limit: number = undefined) {
-    let people: Array<Person> = await mlIDbStorage.getAllPeople();
-    // await mlPeopleStore.iterate((person) => {
-    //     people.push(person);
-    // });
-    people = people ?? [];
-    return people
-        .sort((p1, p2) => p2.files.length - p1.files.length)
-        .slice(0, limit);
-}
-
-export function findFirstIfSorted<T>(
-    elements: Array<T>,
-    comparator: (a: T, b: T) => number,
-) {
-    if (!elements || elements.length < 1) {
-        return;
-    }
-    let first = elements[0];
-
-    for (let i = 1; i < elements.length; i++) {
-        const comp = comparator(elements[i], first);
-        if (comp < 0) {
-            first = elements[i];
-        }
-    }
-
-    return first;
-}
-
-export function isDifferentOrOld(
-    method: Versioned<string>,
-    thanMethod: Versioned<string>,
-) {
-    return (
-        !method ||
-        method.value !== thanMethod.value ||
-        method.version < thanMethod.version
-    );
-}
-
-function primitiveArrayEquals(a, b) {
-    return (
-        Array.isArray(a) &&
-        Array.isArray(b) &&
-        a.length === b.length &&
-        a.every((val, index) => val === b[index])
-    );
-}
-
-export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
-    if (
-        (ofFaces === null || ofFaces === undefined) &&
-        (toFaces === null || toFaces === undefined)
-    ) {
-        return true;
-    }
-    return primitiveArrayEquals(
-        ofFaces?.map((f) => f.id),
-        toFaces?.map((f) => f.id),
-    );
-}
-
-export function logQueueStats(queue: PQueue, name: string) {
-    queue.on("active", () =>
-        log.info(
-            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
-        ),
-    );
-    queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
-    queue.on("error", (error) =>
-        console.error(`queuestats: ${name}: Error, `, error),
-    );
-}
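Reviewer note: getFaceId (deleted here, retained in readerService.ts above) encodes each box edge as the first five decimal digits of its image-relative fraction. Tracing one edge: box.x = 100 in a 1000 px wide image gives 100 / 1000 = 0.1; clamp leaves it unchanged, toFixed(5) yields "0.10000", and substring(2) drops the "0." prefix, leaving "10000". So a face in file 42 gets an ID of the form shown below (the other three fractions are invented for illustration):

const xMin = (100 / 1000).toFixed(5).substring(2); // "10000"
// faceID format: `${fileId}_${xMin}_${yMin}_${xMax}_${yMax}`,
// e.g. "42_10000_05000_30000_45000"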
diff --git a/web/apps/photos/src/utils/machineLearning/mldataMappers.ts b/web/apps/photos/src/utils/machineLearning/mldataMappers.ts
deleted file mode 100644
index 520fa94d53..0000000000
--- a/web/apps/photos/src/utils/machineLearning/mldataMappers.ts
+++ /dev/null
@@ -1,177 +0,0 @@
-import {
-    ClipEmbedding,
-    Face,
-    FaceDetection,
-    Landmark,
-    MlFileData,
-} from "services/ml/types";
-
-export interface FileML extends ServerFileMl {
-    updatedAt: number;
-}
-
-class ServerFileMl {
-    public fileID: number;
-    public height?: number;
-    public width?: number;
-    public faceEmbedding: ServerFaceEmbeddings;
-    public clipEmbedding?: ClipEmbedding;
-
-    public constructor(
-        fileID: number,
-        faceEmbedding: ServerFaceEmbeddings,
-        clipEmbedding?: ClipEmbedding,
-        height?: number,
-        width?: number,
-    ) {
-        this.fileID = fileID;
-        this.height = height;
-        this.width = width;
-        this.faceEmbedding = faceEmbedding;
-        this.clipEmbedding = clipEmbedding;
-    }
-}
-
-class ServerFaceEmbeddings {
-    public faces: ServerFace[];
-    public version: number;
-    public client?: string;
-    public error?: boolean;
-
-    public constructor(
-        faces: ServerFace[],
-        version: number,
-        client?: string,
-        error?: boolean,
-    ) {
-        this.faces = faces;
-        this.version = version;
-        this.client = client;
-        this.error = error;
-    }
-}
-
-class ServerFace {
-    public fileID: number;
-    public faceID: string;
-    public embeddings: number[];
-    public detection: ServerDetection;
-    public score: number;
-    public blur: number;
-    public fileInfo?: ServerFileInfo;
-
-    public constructor(
-        fileID: number,
-        faceID: string,
-        embeddings: number[],
-        detection: ServerDetection,
-        score: number,
-        blur: number,
-        fileInfo?: ServerFileInfo,
-    ) {
-        this.fileID = fileID;
-        this.faceID = faceID;
-        this.embeddings = embeddings;
-        this.detection = detection;
-        this.score = score;
-        this.blur = blur;
-        this.fileInfo = fileInfo;
-    }
-}
-
-class ServerFileInfo {
-    public imageWidth?: number;
-    public imageHeight?: number;
-
-    public constructor(imageWidth?: number, imageHeight?: number) {
-        this.imageWidth = imageWidth;
-        this.imageHeight = imageHeight;
-    }
-}
-
-class ServerDetection {
-    public box: ServerFaceBox;
-    public landmarks: Landmark[];
-
-    public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
-        this.box = box;
-        this.landmarks = landmarks;
-    }
-}
-
-class ServerFaceBox {
-    public xMin: number;
-    public yMin: number;
-    public width: number;
-    public height: number;
-
-    public constructor(
-        xMin: number,
-        yMin: number,
-        width: number,
-        height: number,
-    ) {
-        this.xMin = xMin;
-        this.yMin = yMin;
-        this.width = width;
-        this.height = height;
-    }
-}
-
-export function LocalFileMlDataToServerFileMl(
-    localFileMlData: MlFileData,
-): ServerFileMl {
-    if (
-        localFileMlData.errorCount > 0 &&
-        localFileMlData.lastErrorMessage !== undefined
-    ) {
-        return null;
-    }
-    const imageDimensions = localFileMlData.imageDimensions;
-    const fileInfo = new ServerFileInfo(
-        imageDimensions.width,
-        imageDimensions.height,
-    );
-    const faces: ServerFace[] = [];
-    for (let i = 0; i < localFileMlData.faces.length; i++) {
-        const face: Face = localFileMlData.faces[i];
-        const faceID = face.id;
-        const embedding = face.embedding;
-        const score = face.detection.probability;
-        const blur = face.blurValue;
-        const detection: FaceDetection = face.detection;
-        const box = detection.box;
-        const landmarks = detection.landmarks;
-        const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
-        const newLandmarks: Landmark[] = [];
-        for (let j = 0; j < landmarks.length; j++) {
-            newLandmarks.push({
-                x: landmarks[j].x,
-                y: landmarks[j].y,
-            } as Landmark);
-        }
-
-        const newFaceObject = new ServerFace(
-            localFileMlData.fileId,
-            faceID,
-            Array.from(embedding),
-            new ServerDetection(newBox, newLandmarks),
-            score,
-            blur,
-            fileInfo,
-        );
-        faces.push(newFaceObject);
-    }
-    const faceEmbeddings = new ServerFaceEmbeddings(
-        faces,
-        1,
-        localFileMlData.lastErrorMessage,
-    );
-    return new ServerFileMl(
-        localFileMlData.fileId,
-        faceEmbeddings,
-        null,
-        imageDimensions.height,
-        imageDimensions.width,
-    );
-}
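Reviewer note: LocalFileMlDataToServerFileMl flattens local per-file ML data into the server-side classes above. An illustrative object of the resulting shape (all values invented); note that the call passes lastErrorMessage as the third ServerFaceEmbeddings constructor argument, which is the client slot per the signature:

// Hypothetical serialized result; field names follow the classes in this diff.
const examplePayload = {
    fileID: 42,
    height: 2000,
    width: 3000,
    faceEmbedding: {
        version: 1,
        faces: [
            {
                fileID: 42,
                faceID: "42_10000_05000_30000_45000",
                embeddings: [0.01, -0.02], // one float per embedding dimension
                detection: {
                    box: { xMin: 0.1, yMin: 0.05, width: 0.2, height: 0.4 },
                    landmarks: [],
                },
                score: 0.92,
                blur: 18.4,
                fileInfo: { imageWidth: 3000, imageHeight: 2000 },
            },
        ],
    },
};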
diff --git a/web/apps/photos/src/utils/machineLearning/transform.ts b/web/apps/photos/src/utils/machineLearning/transform.ts
deleted file mode 100644
index e69de29bb2..0000000000