From e134d599b7451b4079c230a375021c4805043074 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:45:18 +0530 Subject: [PATCH] [mob] Use stream to optimize processing --- .../face_ml/face_recognition_service.dart | 44 +++++++++++-------- .../file_ml/files_ml_data_response.dart | 13 ++++++ .../services/machine_learning/ml_service.dart | 19 ++++---- 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart index db5bf2c047..392a9df20f 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart @@ -67,7 +67,6 @@ class FaceRecognitionService { Future sync() async { await _syncPersonFeedback(); if (LocalSettings.instance.remoteFetchEnabled) { - await _syncFaceEmbeddings(); } else { _logger.severe( 'Not fetching embeddings because user manually disabled it in debug options', @@ -88,22 +87,24 @@ class FaceRecognitionService { _isSyncing = false; } - Future> _syncFaceEmbeddings() async { + Stream> syncEmbeddings({ + int yieldSize = 10, + }) async* { final List filesToIndex = await getFilesForMlIndexing(); - final Map pendingIndex = {}; final List> chunks = - filesToIndex.chunks(_embeddingFetchLimit); // Chunks of 200 - int filesIndexedForFaces = 0; - int filesIndexedForClip = 0; + filesToIndex.chunks(_embeddingFetchLimit); + List batchToYield = []; + for (final chunk in chunks) { final Set ids = {}; + final Map pendingIndex = {}; for (final instruction in chunk) { ids.add(instruction.file.uploadedFileID!); pendingIndex[instruction.file.uploadedFileID!] = instruction; } - _logger.info('starting remote fetch for ${ids.length} files'); + _logger.info("fetching embeddings for ${ids.length} files"); final res = await RemoteFileMLService.instance.getFileEmbeddings(ids); - _logger.info('fetched ${res.mlData.length} embeddings'); + _logger.info("embeddingResponse ${res.debugLog()}"); final List faces = []; final List clipEmbeddings = []; for (RemoteFileML fileMl in res.mlData.values) { @@ -112,7 +113,6 @@ class FaceRecognitionService { //Note: Always do null check, empty value means no face was found. if (facesFromRemoteEmbedding != null) { faces.addAll(facesFromRemoteEmbedding); - filesIndexedForFaces++; existingInstruction.shouldRunFaces = false; } if (fileMl.clipEmbedding != null && @@ -125,7 +125,6 @@ class FaceRecognitionService { ), ); existingInstruction.shouldRunClip = false; - filesIndexedForClip++; } if (!existingInstruction.pendingML) { pendingIndex.remove(fileMl.fileID); @@ -134,23 +133,30 @@ class FaceRecognitionService { pendingIndex[fileMl.fileID] = existingInstruction; } } - + for (final fileID in pendingIndex.keys) { + final instruction = pendingIndex[fileID]!; + if (instruction.pendingML) { + batchToYield.add(instruction); + if (batchToYield.length == yieldSize) { + _logger.info("queueing indexing for $yieldSize"); + yield batchToYield; + batchToYield = []; + } + } + } if (res.noEmbeddingFileIDs.isNotEmpty) { - _logger.info( - 'No embeddings found for ${res.noEmbeddingFileIDs.length} files', - ); for (final fileID in res.noEmbeddingFileIDs) { faces.add(Face.empty(fileID, error: false)); } } await FaceMLDataDB.instance.bulkInsertFaces(faces); await EmbeddingsDB.instance.putMany(clipEmbeddings); - _logger.info( - 'Embedding store files for face $filesIndexedForFaces, and clip $filesIndexedForClip', - ); } - final List instructions = pendingIndex.values.toList(); - return instructions; + // Yield any remaining instructions + if (batchToYield.isNotEmpty) { + _logger.info("queueing indexing for $batchToYield.length"); + yield batchToYield; + } } // Returns a list of faces from the given remote fileML. null if the version is less than the current version diff --git a/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart index 67b69dcf5c..56da3adc58 100644 --- a/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart +++ b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart @@ -16,4 +16,17 @@ class FilesMLDataResponse { required this.fetchErrorFileIDs, required this.pendingIndexFileIDs, }); + + String debugLog() { + final nonZeroNoEmbeddingFileIDs = noEmbeddingFileIDs.isNotEmpty + ? ', smallEmbeddings: ${noEmbeddingFileIDs.length}' + : ''; + final nonZeroFetchErrorFileIDs = fetchErrorFileIDs.isNotEmpty + ? ', errorForFileIDs: ${fetchErrorFileIDs.length}' + : ''; + final nonZeroPendingIndexFileIDs = pendingIndexFileIDs.isNotEmpty + ? ', pendingIndexFileIDs: ${pendingIndexFileIDs.length}' + : ''; + return 'MLRemote(mlData: ${mlData.length}$nonZeroNoEmbeddingFileIDs$nonZeroFetchErrorFileIDs$nonZeroPendingIndexFileIDs)'; + } } diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 6c2d6f7074..3f2e96aac3 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -14,7 +14,6 @@ import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; import "package:photos/events/machine_learning_control_event.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/extensions/list.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/box.dart"; import "package:photos/face/model/detection.dart" as face_detection; @@ -61,7 +60,9 @@ class MLService { // Singleton pattern MLService._privateConstructor(); + static final instance = MLService._privateConstructor(); + factory MLService() => instance; final _initModelLock = Lock(); @@ -175,26 +176,25 @@ class MLService { try { _isIndexingOrClusteringRunning = true; _logger.info('starting image indexing'); - - final filesToIndex = await getFilesForMlIndexing(); - final List> chunks = - filesToIndex.chunks(_fileDownloadLimit); + final Stream> instructionStream = + FaceRecognitionService.instance + .syncEmbeddings(yieldSize: _fileDownloadLimit); int fileAnalyzedCount = 0; final Stopwatch stopwatch = Stopwatch()..start(); - outerLoop: - for (final chunk in chunks) { + + await for (final chunk in instructionStream) { if (!await canUseHighBandwidth()) { _logger.info( 'stopping indexing because user is not connected to wifi', ); - break outerLoop; + break; } final futures = >[]; for (final instruction in chunk) { if (_shouldPauseIndexingAndClustering) { _logger.info("indexAllImages() was paused, stopping"); - break outerLoop; + break; } await _ensureReadyForInference(); futures.add(processImage(instruction)); @@ -206,7 +206,6 @@ class MLService { ); fileAnalyzedCount += sumFutures; } - _logger.info( "`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image)", );