diff --git a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts b/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts deleted file mode 100644 index 999bae8fe9..0000000000 --- a/web/apps/photos/src/services/machineLearning/machineLearningFactory.ts +++ /dev/null @@ -1,216 +0,0 @@ -import { haveWindow } from "@/next/env"; -import log from "@/next/log"; -import { ComlinkWorker } from "@/next/worker/comlink-worker"; -import { getDedicatedCryptoWorker } from "@ente/shared/crypto"; -import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker"; -import PQueue from "p-queue"; -import { - BlurDetectionMethod, - BlurDetectionService, - ClusteringMethod, - ClusteringService, - Face, - FaceAlignmentMethod, - FaceAlignmentService, - FaceCropMethod, - FaceCropService, - FaceDetectionMethod, - FaceDetectionService, - FaceEmbeddingMethod, - FaceEmbeddingService, - MLLibraryData, - MLSyncConfig, - MLSyncContext, -} from "services/ml/types"; -import { EnteFile } from "types/file"; -import { logQueueStats } from "utils/machineLearning"; -import arcfaceAlignmentService from "./arcfaceAlignmentService"; -import arcfaceCropService from "./arcfaceCropService"; -import dbscanClusteringService from "./dbscanClusteringService"; -import hdbscanClusteringService from "./hdbscanClusteringService"; -import laplacianBlurDetectionService from "./laplacianBlurDetectionService"; -import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService"; -import yoloFaceDetectionService from "./yoloFaceDetectionService"; - -export class MLFactory { - public static getFaceDetectionService( - method: FaceDetectionMethod, - ): FaceDetectionService { - if (method === "YoloFace") { - return yoloFaceDetectionService; - } - - throw Error("Unknon face detection method: " + method); - } - - public static getFaceCropService(method: FaceCropMethod) { - if (method === "ArcFace") { - return arcfaceCropService; - } - - throw Error("Unknon face crop method: " + method); - } - - public static getFaceAlignmentService( - method: FaceAlignmentMethod, - ): FaceAlignmentService { - if (method === "ArcFace") { - return arcfaceAlignmentService; - } - - throw Error("Unknon face alignment method: " + method); - } - - public static getBlurDetectionService( - method: BlurDetectionMethod, - ): BlurDetectionService { - if (method === "Laplacian") { - return laplacianBlurDetectionService; - } - - throw Error("Unknon blur detection method: " + method); - } - - public static getFaceEmbeddingService( - method: FaceEmbeddingMethod, - ): FaceEmbeddingService { - if (method === "MobileFaceNet") { - return mobileFaceNetEmbeddingService; - } - - throw Error("Unknon face embedding method: " + method); - } - - public static getClusteringService( - method: ClusteringMethod, - ): ClusteringService { - if (method === "Hdbscan") { - return hdbscanClusteringService; - } - if (method === "Dbscan") { - return dbscanClusteringService; - } - - throw Error("Unknon clustering method: " + method); - } - - public static getMLSyncContext( - token: string, - userID: number, - config: MLSyncConfig, - shouldUpdateMLVersion: boolean = true, - ) { - return new LocalMLSyncContext( - token, - userID, - config, - shouldUpdateMLVersion, - ); - } -} - -export class LocalMLSyncContext implements MLSyncContext { - public token: string; - public userID: number; - public config: MLSyncConfig; - public shouldUpdateMLVersion: boolean; - - public faceDetectionService: FaceDetectionService; - public faceCropService: FaceCropService; - public faceAlignmentService: FaceAlignmentService; - public blurDetectionService: BlurDetectionService; - public faceEmbeddingService: FaceEmbeddingService; - public faceClusteringService: ClusteringService; - - public localFilesMap: Map; - public outOfSyncFiles: EnteFile[]; - public nSyncedFiles: number; - public nSyncedFaces: number; - public allSyncedFacesMap?: Map>; - - public error?: Error; - - public mlLibraryData: MLLibraryData; - - public syncQueue: PQueue; - // TODO: wheather to limit concurrent downloads - // private downloadQueue: PQueue; - - private concurrency: number; - private comlinkCryptoWorker: Array< - ComlinkWorker - >; - private enteWorkers: Array; - - constructor( - token: string, - userID: number, - config: MLSyncConfig, - shouldUpdateMLVersion: boolean = true, - concurrency?: number, - ) { - this.token = token; - this.userID = userID; - this.config = config; - this.shouldUpdateMLVersion = shouldUpdateMLVersion; - - this.faceDetectionService = MLFactory.getFaceDetectionService( - this.config.faceDetection.method, - ); - this.faceCropService = MLFactory.getFaceCropService( - this.config.faceCrop.method, - ); - this.faceAlignmentService = MLFactory.getFaceAlignmentService( - this.config.faceAlignment.method, - ); - this.blurDetectionService = MLFactory.getBlurDetectionService( - this.config.blurDetection.method, - ); - this.faceEmbeddingService = MLFactory.getFaceEmbeddingService( - this.config.faceEmbedding.method, - ); - this.faceClusteringService = MLFactory.getClusteringService( - this.config.faceClustering.method, - ); - - this.outOfSyncFiles = []; - this.nSyncedFiles = 0; - this.nSyncedFaces = 0; - - this.concurrency = concurrency ?? getConcurrency(); - - log.info("Using concurrency: ", this.concurrency); - // timeout is added on downloads - // timeout on queue will keep the operation open till worker is terminated - this.syncQueue = new PQueue({ concurrency: this.concurrency }); - logQueueStats(this.syncQueue, "sync"); - // this.downloadQueue = new PQueue({ concurrency: 1 }); - // logQueueStats(this.downloadQueue, 'download'); - - this.comlinkCryptoWorker = new Array(this.concurrency); - this.enteWorkers = new Array(this.concurrency); - } - - public async getEnteWorker(id: number): Promise { - const wid = id % this.enteWorkers.length; - console.log("getEnteWorker: ", id, wid); - if (!this.enteWorkers[wid]) { - this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker(); - this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote; - } - - return this.enteWorkers[wid]; - } - - public async dispose() { - this.localFilesMap = undefined; - await this.syncQueue.onIdle(); - this.syncQueue.removeAllListeners(); - for (const enteComlinkWorker of this.comlinkCryptoWorker) { - enteComlinkWorker?.terminate(); - } - } -} - -export const getConcurrency = () => - haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); diff --git a/web/apps/photos/src/services/machineLearning/machineLearningService.ts b/web/apps/photos/src/services/machineLearning/machineLearningService.ts index 99c9ddf98c..5050b9ff7a 100644 --- a/web/apps/photos/src/services/machineLearning/machineLearningService.ts +++ b/web/apps/photos/src/services/machineLearning/machineLearningService.ts @@ -26,10 +26,38 @@ import mlIDbStorage, { } from "utils/storage/mlIDbStorage"; import { isInternalUserForML } from "utils/user"; import FaceService from "./faceService"; -import { MLFactory } from "./machineLearningFactory"; import PeopleService from "./peopleService"; import ReaderService from "./readerService"; +import { haveWindow } from "@/next/env"; +import { ComlinkWorker } from "@/next/worker/comlink-worker"; +import { getDedicatedCryptoWorker } from "@ente/shared/crypto"; +import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker"; +import PQueue from "p-queue"; +import { + BlurDetectionMethod, + BlurDetectionService, + ClusteringMethod, + ClusteringService, + FaceAlignmentMethod, + FaceAlignmentService, + FaceCropMethod, + FaceCropService, + FaceDetectionMethod, + FaceDetectionService, + FaceEmbeddingMethod, + FaceEmbeddingService, + MLLibraryData, +} from "services/ml/types"; +import { logQueueStats } from "utils/machineLearning"; +import arcfaceAlignmentService from "./arcfaceAlignmentService"; +import arcfaceCropService from "./arcfaceCropService"; +import dbscanClusteringService from "./dbscanClusteringService"; +import hdbscanClusteringService from "./hdbscanClusteringService"; +import laplacianBlurDetectionService from "./laplacianBlurDetectionService"; +import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService"; +import yoloFaceDetectionService from "./yoloFaceDetectionService"; + export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = { intervalSec: 5, // TODO: finalize this after seeing effects on and from machine sleep @@ -119,6 +147,189 @@ export async function updateMLSearchConfig(newConfig: MLSearchConfig) { return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig); } +export class MLFactory { + public static getFaceDetectionService( + method: FaceDetectionMethod, + ): FaceDetectionService { + if (method === "YoloFace") { + return yoloFaceDetectionService; + } + + throw Error("Unknon face detection method: " + method); + } + + public static getFaceCropService(method: FaceCropMethod) { + if (method === "ArcFace") { + return arcfaceCropService; + } + + throw Error("Unknon face crop method: " + method); + } + + public static getFaceAlignmentService( + method: FaceAlignmentMethod, + ): FaceAlignmentService { + if (method === "ArcFace") { + return arcfaceAlignmentService; + } + + throw Error("Unknon face alignment method: " + method); + } + + public static getBlurDetectionService( + method: BlurDetectionMethod, + ): BlurDetectionService { + if (method === "Laplacian") { + return laplacianBlurDetectionService; + } + + throw Error("Unknon blur detection method: " + method); + } + + public static getFaceEmbeddingService( + method: FaceEmbeddingMethod, + ): FaceEmbeddingService { + if (method === "MobileFaceNet") { + return mobileFaceNetEmbeddingService; + } + + throw Error("Unknon face embedding method: " + method); + } + + public static getClusteringService( + method: ClusteringMethod, + ): ClusteringService { + if (method === "Hdbscan") { + return hdbscanClusteringService; + } + if (method === "Dbscan") { + return dbscanClusteringService; + } + + throw Error("Unknon clustering method: " + method); + } + + public static getMLSyncContext( + token: string, + userID: number, + config: MLSyncConfig, + shouldUpdateMLVersion: boolean = true, + ) { + return new LocalMLSyncContext( + token, + userID, + config, + shouldUpdateMLVersion, + ); + } +} + +export class LocalMLSyncContext implements MLSyncContext { + public token: string; + public userID: number; + public config: MLSyncConfig; + public shouldUpdateMLVersion: boolean; + + public faceDetectionService: FaceDetectionService; + public faceCropService: FaceCropService; + public faceAlignmentService: FaceAlignmentService; + public blurDetectionService: BlurDetectionService; + public faceEmbeddingService: FaceEmbeddingService; + public faceClusteringService: ClusteringService; + + public localFilesMap: Map; + public outOfSyncFiles: EnteFile[]; + public nSyncedFiles: number; + public nSyncedFaces: number; + public allSyncedFacesMap?: Map>; + + public error?: Error; + + public mlLibraryData: MLLibraryData; + + public syncQueue: PQueue; + // TODO: wheather to limit concurrent downloads + // private downloadQueue: PQueue; + + private concurrency: number; + private comlinkCryptoWorker: Array< + ComlinkWorker + >; + private enteWorkers: Array; + + constructor( + token: string, + userID: number, + config: MLSyncConfig, + shouldUpdateMLVersion: boolean = true, + concurrency?: number, + ) { + this.token = token; + this.userID = userID; + this.config = config; + this.shouldUpdateMLVersion = shouldUpdateMLVersion; + + this.faceDetectionService = MLFactory.getFaceDetectionService( + this.config.faceDetection.method, + ); + this.faceCropService = MLFactory.getFaceCropService( + this.config.faceCrop.method, + ); + this.faceAlignmentService = MLFactory.getFaceAlignmentService( + this.config.faceAlignment.method, + ); + this.blurDetectionService = MLFactory.getBlurDetectionService( + this.config.blurDetection.method, + ); + this.faceEmbeddingService = MLFactory.getFaceEmbeddingService( + this.config.faceEmbedding.method, + ); + this.faceClusteringService = MLFactory.getClusteringService( + this.config.faceClustering.method, + ); + + this.outOfSyncFiles = []; + this.nSyncedFiles = 0; + this.nSyncedFaces = 0; + + this.concurrency = concurrency ?? getConcurrency(); + + log.info("Using concurrency: ", this.concurrency); + // timeout is added on downloads + // timeout on queue will keep the operation open till worker is terminated + this.syncQueue = new PQueue({ concurrency: this.concurrency }); + logQueueStats(this.syncQueue, "sync"); + // this.downloadQueue = new PQueue({ concurrency: 1 }); + // logQueueStats(this.downloadQueue, 'download'); + + this.comlinkCryptoWorker = new Array(this.concurrency); + this.enteWorkers = new Array(this.concurrency); + } + + public async getEnteWorker(id: number): Promise { + const wid = id % this.enteWorkers.length; + console.log("getEnteWorker: ", id, wid); + if (!this.enteWorkers[wid]) { + this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker(); + this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote; + } + + return this.enteWorkers[wid]; + } + + public async dispose() { + this.localFilesMap = undefined; + await this.syncQueue.onIdle(); + this.syncQueue.removeAllListeners(); + for (const enteComlinkWorker of this.comlinkCryptoWorker) { + enteComlinkWorker?.terminate(); + } + } +} + +export const getConcurrency = () => + haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); + class MachineLearningService { private localSyncContext: Promise; private syncContext: Promise;