From 68801398885e56848968cb27139fc73f797f79d1 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 24 Jul 2024 15:53:26 +0200 Subject: [PATCH] [mob][photos] Separate ML Isolate --- .../services/machine_learning/ml_isolate.dart | 265 +++++++++++++++++ .../services/machine_learning/ml_service.dart | 276 ++---------------- .../semantic_search_service.dart | 4 +- 3 files changed, 284 insertions(+), 261 deletions(-) create mode 100644 mobile/lib/services/machine_learning/ml_isolate.dart diff --git a/mobile/lib/services/machine_learning/ml_isolate.dart b/mobile/lib/services/machine_learning/ml_isolate.dart new file mode 100644 index 0000000000..720f45d8c4 --- /dev/null +++ b/mobile/lib/services/machine_learning/ml_isolate.dart @@ -0,0 +1,265 @@ +import "dart:async"; +import "dart:io" show Platform; +import "dart:isolate"; + +import "package:dart_ui_isolate/dart_ui_isolate.dart"; +import "package:flutter/foundation.dart" show debugPrint, kDebugMode; +import "package:logging/logging.dart"; +import "package:photos/core/error-reporting/super_logging.dart"; +import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; +import 'package:photos/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart'; +import 'package:photos/services/machine_learning/ml_result.dart'; +import "package:photos/services/machine_learning/ml_service.dart"; +import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart"; +import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart"; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/ml_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum MLOperation { analyzeImage, loadModels, runClipText } + +class MLIsolate { + static final _logger = Logger("MLIsolate"); + + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 120); + int _activeTasks = 0; + + final _functionLock = Lock(); + final _initIsolateLock = Lock(); + + late DartUiIsolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool _isIsolateSpawned = false; + + // Singleton pattern + MLIsolate._privateConstructor(); + static final instance = MLIsolate._privateConstructor(); + factory MLIsolate() => instance; + + Future _initIsolate() async { + return _initIsolateLock.synchronized(() async { + if (_isIsolateSpawned) return; + _logger.info("initIsolate called"); + + _receivePort = ReceivePort(); + + try { + _isolate = await DartUiIsolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + _isIsolateSpawned = true; + + _resetInactivityTimer(); + _logger.info('initIsolate done'); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + _isIsolateSpawned = false; + } + }); + } + + /// The main execution function of the isolate. + @pragma('vm:entry-point') + static void _isolateMain(SendPort mainSendPort) async { + Logger.root.level = kDebugMode ? 
Level.ALL : Level.INFO;
+    Logger.root.onRecord.listen((LogRecord rec) {
+      debugPrint('[MLIsolate] ${rec.toPrettyString()}');
+    });
+    final receivePort = ReceivePort();
+    mainSendPort.send(receivePort.sendPort);
+    receivePort.listen((message) async {
+      final functionIndex = message[0] as int;
+      final function = MLOperation.values[functionIndex];
+      final args = message[1] as Map;
+      final sendPort = message[2] as SendPort;
+
+      try {
+        switch (function) {
+          case MLOperation.analyzeImage:
+            final time = DateTime.now();
+            final MLResult result = await MLService.analyzeImageSync(args);
+            _logger.info(
+              "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms",
+            );
+            sendPort.send(result.toJsonString());
+            break;
+          case MLOperation.loadModels:
+            await FaceDetectionService.instance.loadModel(useEntePlugin: true);
+            await FaceEmbeddingService.instance.loadModel(useEntePlugin: true);
+            await ClipImageEncoder.instance.loadModel(useEntePlugin: true);
+            sendPort.send(true);
+            break;
+          case MLOperation.runClipText:
+            final textEmbedding = await ClipTextEncoder.predict(args);
+            sendPort.send(List<double>.from(textEmbedding, growable: false));
+            break;
+        }
+      } catch (e, s) {
+        _logger.severe("Error in ML isolate", e, s);
+        sendPort.send({'error': e.toString(), 'stackTrace': s.toString()});
+      }
+    });
+  }
+
+  /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
+  Future _runInIsolate(
+    (MLOperation, Map) message,
+  ) async {
+    await _initIsolate();
+    return _functionLock.synchronized(() async {
+      _resetInactivityTimer();
+
+      if (message.$1 == MLOperation.analyzeImage &&
+          MLService.instance.shouldPauseIndexingAndClustering) {
+        return null;
+      }
+
+      final completer = Completer();
+      final answerPort = ReceivePort();
+
+      _activeTasks++;
+      _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
+
+      answerPort.listen((receivedMessage) {
+        // Only mark the task as done once the isolate has replied, so the
+        // inactivity timer cannot treat in-flight work as inactivity.
+        _activeTasks--;
+        if (receivedMessage is Map && receivedMessage.containsKey('error')) {
+          // Handle the error
+          final errorMessage = receivedMessage['error'];
+          final errorStackTrace = receivedMessage['stackTrace'];
+          final exception = Exception(errorMessage);
+          final stackTrace = StackTrace.fromString(errorStackTrace);
+          completer.completeError(exception, stackTrace);
+        } else {
+          completer.complete(receivedMessage);
+        }
+      });
+
+      return completer.future;
+    });
+  }
+
+  /// Resets a timer that kills the isolate after a certain amount of inactivity.
+  ///
+  /// Should be called after initialization (e.g. inside `init()`) and after every call to the isolate (e.g. inside `_runInIsolate()`)
+  void _resetInactivityTimer() {
+    _inactivityTimer?.cancel();
+    _inactivityTimer = Timer(_inactivityDuration, () {
+      if (_activeTasks > 0) {
+        _logger.info('Tasks are still running. Delaying isolate disposal.');
+        // Optionally, reschedule the timer to check again later.
+        _resetInactivityTimer();
+      } else {
+        _logger.info(
+          'ML isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
+        );
+        _dispose();
+      }
+    });
+  }
+
+  void _dispose() async {
+    if (!_isIsolateSpawned) return;
+    _logger.info('Disposing isolate and models');
+    // await _releaseModels(); TODO: Implement this
+    _isIsolateSpawned = false;
+    _isolate.kill();
+    _receivePort.close();
+    _inactivityTimer?.cancel();
+  }
+
+  /// Analyzes the given image data by running the full ML pipeline (face detection and embedding, plus CLIP image embedding), using [MLService.analyzeImageSync] inside the isolate.
+  Future<MLResult?> analyzeImage(
+    FileMLInstruction instruction,
+  ) async {
+    final String filePath = await getImagePathForML(instruction.enteFile);
+
+    final Stopwatch stopwatch = Stopwatch()..start();
+    late MLResult result;
+
+    try {
+      final resultJsonString = await _runInIsolate(
+        (
+          MLOperation.analyzeImage,
+          {
+            "enteFileID": instruction.enteFile.uploadedFileID ?? -1,
+            "filePath": filePath,
+            "runFaces": instruction.shouldRunFaces,
+            "runClip": instruction.shouldRunClip,
+            "faceDetectionAddress":
+                FaceDetectionService.instance.sessionAddress,
+            "faceEmbeddingAddress":
+                FaceEmbeddingService.instance.sessionAddress,
+            "clipImageAddress": ClipImageEncoder.instance.sessionAddress,
+          }
+        ),
+      ) as String?;
+      if (resultJsonString == null) {
+        if (!MLService.instance.shouldPauseIndexingAndClustering) {
+          _logger.severe('Analyzing image in isolate is giving back null');
+        }
+        return null;
+      }
+      result = MLResult.fromJsonString(resultJsonString);
+    } catch (e, s) {
+      _logger.severe(
+        "Could not analyze image with ID ${instruction.enteFile.uploadedFileID} \n",
+        e,
+        s,
+      );
+      debugPrint(
+        "This image with ID ${instruction.enteFile.uploadedFileID} has name ${instruction.enteFile.displayName}.",
+      );
+      final resultBuilder =
+          MLResult.fromEnteFileID(instruction.enteFile.uploadedFileID!)
+            ..errorOccurred();
+      return resultBuilder;
+    }
+    stopwatch.stop();
+    _logger.info(
+      "Finished Analyze image with uploadedFileID ${instruction.enteFile.uploadedFileID}, in "
+      "${stopwatch.elapsedMilliseconds} ms (including time waiting for inference engine availability)",
+    );
+
+    return result;
+  }
+
+  Future<List<double>> runClipText(String query) async {
+    try {
+      final int clipAddress = ClipTextEncoder.instance.sessionAddress;
+      final String remotePath = ClipTextEncoder.instance.vocabRemotePath;
+      final String tokenizerVocabPath =
+          await RemoteAssetsService.instance.getAssetPath(remotePath);
+      final textEmbedding = await _runInIsolate(
+        (
+          MLOperation.runClipText,
+          {
+            "text": query,
+            "address": clipAddress,
+            "vocabPath": tokenizerVocabPath,
+            "useEntePlugin": Platform.isAndroid,
+          }
+        ),
+      ) as List<double>;
+      return textEmbedding;
+    } catch (e, s) {
+      _logger.severe("Could not run clip text in isolate", e, s);
+      rethrow;
+    }
+  }
+
+  Future<void> loadModels() async {
+    try {
+      await _runInIsolate(
+        (MLOperation.loadModels, {}),
+      );
+    } catch (e, s) {
+      _logger.severe("Could not load models in isolate", e, s);
+      rethrow;
+    }
+  }
+}
diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart
index fd4a31e796..1f9fe06cf9 100644
--- a/mobile/lib/services/machine_learning/ml_service.dart
+++ b/mobile/lib/services/machine_learning/ml_service.dart
@@ -1,15 +1,12 @@
 import "dart:async";
 import "dart:developer" as dev show log;
 import "dart:io" show File, Platform;
-import "dart:isolate";
 import "dart:math" show min;
 import "dart:typed_data" show Uint8List, ByteData;
 
-import "package:dart_ui_isolate/dart_ui_isolate.dart";
-import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
+import "package:flutter/foundation.dart" show debugPrint;
 import "package:logging/logging.dart";
 import "package:package_info_plus/package_info_plus.dart";
-import "package:photos/core/error-reporting/super_logging.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
 import "package:photos/events/machine_learning_control_event.dart";
@@ -32,34 +29,18 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d
 import
'package:photos/services/machine_learning/file_ml/file_ml.dart';
 import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart';
 import 'package:photos/services/machine_learning/ml_exceptions.dart';
+import "package:photos/services/machine_learning/ml_isolate.dart";
 import 'package:photos/services/machine_learning/ml_result.dart';
 import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
-import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
 import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
-import "package:photos/services/remote_assets_service.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/local_settings.dart";
 import "package:photos/utils/ml_util.dart";
 import "package:photos/utils/network_util.dart";
 import "package:synchronized/synchronized.dart";
 
-enum FaceMlOperation { analyzeImage, loadModels, runClipText }
-
-/// This class is responsible for running the full face ml pipeline on images.
-///
-/// WARNING: For getting the ML results needed for the UI, you should use `FaceSearchService` instead of this class!
-///
-/// The pipeline consists of face detection, face alignment and face embedding.
 class MLService {
-  final _logger = Logger("FaceMlService");
-
-  // Flutter isolate things for running the image ml pipeline
-  Timer? _inactivityTimer;
-  final Duration _inactivityDuration = const Duration(seconds: 120);
-  int _activeTasks = 0;
-  late DartUiIsolate _isolate;
-  late ReceivePort _receivePort = ReceivePort();
-  late SendPort _mainSendPort;
+  final _logger = Logger("MLService");
 
   // Singleton pattern
   MLService._privateConstructor();
@@ -67,13 +48,10 @@ class MLService {
   factory MLService() => instance;
 
   final _initModelLock = Lock();
-  final _functionLock = Lock();
-  final _initIsolateLock = Lock();
 
   bool _isInitialized = false;
   bool _isModelsInitialized = false;
   bool _isModelsInitUsingEntePlugin = false;
-  bool _isIsolateSpawned = false;
 
   late String client;
 
@@ -87,7 +65,7 @@ class MLService {
   bool _showClusteringIsHappening = false;
   bool _mlControllerStatus = false;
   bool _isIndexingOrClusteringRunning = false;
-  bool _shouldPauseIndexingAndClustering = false;
+  bool shouldPauseIndexingAndClustering = false;
 
   static const int _fileDownloadLimit = 10;
   static const _kForceClusteringFaceCount = 8000;
@@ -110,8 +88,8 @@
       }
       _mlControllerStatus = event.shouldRun;
       if (_mlControllerStatus) {
-        if (_shouldPauseIndexingAndClustering) {
-          _shouldPauseIndexingAndClustering = false;
+        if (shouldPauseIndexingAndClustering) {
+          shouldPauseIndexingAndClustering = false;
           _logger.info(
             "MLController allowed running ML, faces indexing undoing previous pause",
           );
@@ -160,7 +138,7 @@
 
   void pauseIndexingAndClustering() {
     if (_isIndexingOrClusteringRunning) {
-      _shouldPauseIndexingAndClustering = true;
+      shouldPauseIndexingAndClustering = true;
     }
   }
 
@@ -191,7 +169,7 @@
       }
       final futures = <Future<bool>>[];
       for (final instruction in chunk) {
-        if (_shouldPauseIndexingAndClustering) {
+        if (shouldPauseIndexingAndClustering) {
           _logger.info("indexAllImages() was paused, stopping");
           break outerLoop;
         }
@@ -214,7 +192,7 @@
       _logger.severe("indexAllImages failed", e, s);
     } finally {
       _isIndexingOrClusteringRunning = false;
-      _shouldPauseIndexingAndClustering = false;
+      shouldPauseIndexingAndClustering = false;
     }
   }
 
@@ -275,7 +253,7 @@
     int bucket = 1;
     while (true) {
- if (_shouldPauseIndexingAndClustering) { + if (shouldPauseIndexingAndClustering) { _logger.info( "MLController does not allow running ML, stopping before clustering bucket $bucket", ); @@ -385,7 +363,7 @@ class MLService { } finally { _showClusteringIsHappening = false; _isIndexingOrClusteringRunning = false; - _shouldPauseIndexingAndClustering = false; + shouldPauseIndexingAndClustering = false; } } @@ -397,11 +375,11 @@ class MLService { bool actuallyRanML = false; try { - final MLResult? result = await _analyzeImageInSingleIsolate( + final MLResult? result = await MLIsolate.instance.analyzeImage( instruction, ); if (result == null) { - if (!_shouldPauseIndexingAndClustering) { + if (!shouldPauseIndexingAndClustering) { _logger.severe( "Failed to analyze image with uploadedFileID: ${instruction.enteFile.uploadedFileID}", ); @@ -562,9 +540,7 @@ class MLService { // Initialize models try { - await _runInIsolate( - (FaceMlOperation.loadModels, {}), - ); + await MLIsolate.instance.loadModels(); _isModelsInitUsingEntePlugin = true; } catch (e, s) { _logger.severe("Could not initialize clip image", e, s); @@ -599,32 +575,7 @@ class MLService { }); } - Future _initIsolate() async { - return _initIsolateLock.synchronized(() async { - if (_isIsolateSpawned) return; - _logger.info("initIsolate called"); - - _receivePort = ReceivePort(); - - try { - _isolate = await DartUiIsolate.spawn( - _isolateMain, - _receivePort.sendPort, - ); - _mainSendPort = await _receivePort.first as SendPort; - _isIsolateSpawned = true; - - _resetInactivityTimer(); - _logger.info('initIsolate done'); - } catch (e) { - _logger.severe('Could not spawn isolate', e); - _isIsolateSpawned = false; - } - }); - } - Future _ensureReadyForInference() async { - await _initIsolate(); await _initModelsUsingFfiBasedPlugin(); if (Platform.isAndroid) { await _initModelUsingEntePlugin(); @@ -633,200 +584,7 @@ class MLService { } } - /// The main execution function of the isolate. - @pragma('vm:entry-point') - static void _isolateMain(SendPort mainSendPort) async { - Logger.root.level = kDebugMode ? 
Level.ALL : Level.INFO;
-    Logger.root.onRecord.listen((LogRecord rec) {
-      debugPrint('[MLIsolate] ${rec.toPrettyString()}');
-    });
-    final receivePort = ReceivePort();
-    mainSendPort.send(receivePort.sendPort);
-    receivePort.listen((message) async {
-      final functionIndex = message[0] as int;
-      final function = FaceMlOperation.values[functionIndex];
-      final args = message[1] as Map;
-      final sendPort = message[2] as SendPort;
-
-      try {
-        switch (function) {
-          case FaceMlOperation.analyzeImage:
-            final time = DateTime.now();
-            final MLResult result = await MLService._analyzeImageSync(args);
-            dev.log(
-              "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms",
-            );
-            sendPort.send(result.toJsonString());
-            break;
-          case FaceMlOperation.loadModels:
-            await FaceDetectionService.instance.loadModel(useEntePlugin: true);
-            await FaceEmbeddingService.instance.loadModel(useEntePlugin: true);
-            await ClipImageEncoder.instance.loadModel(useEntePlugin: true);
-            sendPort.send(true);
-            break;
-          case FaceMlOperation.runClipText:
-            final textEmbedding = await ClipTextEncoder.predict(args);
-            sendPort.send(List<double>.from(textEmbedding, growable: false));
-            break;
-        }
-      } catch (e, stackTrace) {
-        dev.log(
-          "[SEVERE] Error in FaceML isolate: $e",
-          error: e,
-          stackTrace: stackTrace,
-        );
-        sendPort
-            .send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
-      }
-    });
-  }
-
-  /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
-  Future _runInIsolate(
-    (FaceMlOperation, Map) message,
-  ) async {
-    await _initIsolate();
-    return _functionLock.synchronized(() async {
-      _resetInactivityTimer();
-
-      if (_shouldPauseIndexingAndClustering) {
-        return null;
-      }
-
-      final completer = Completer();
-      final answerPort = ReceivePort();
-
-      _activeTasks++;
-      _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
-
-      answerPort.listen((receivedMessage) {
-        if (receivedMessage is Map && receivedMessage.containsKey('error')) {
-          // Handle the error
-          final errorMessage = receivedMessage['error'];
-          final errorStackTrace = receivedMessage['stackTrace'];
-          final exception = Exception(errorMessage);
-          final stackTrace = StackTrace.fromString(errorStackTrace);
-          completer.completeError(exception, stackTrace);
-        } else {
-          completer.complete(receivedMessage);
-        }
-      });
-      _activeTasks--;
-
-      return completer.future;
-    });
-  }
-
-  /// Resets a timer that kills the isolate after a certain amount of inactivity.
-  ///
-  /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
-  void _resetInactivityTimer() {
-    _inactivityTimer?.cancel();
-    _inactivityTimer = Timer(_inactivityDuration, () {
-      if (_activeTasks > 0) {
-        _logger.info('Tasks are still running. Delaying isolate disposal.');
-        // Optionally, reschedule the timer to check again later.
-        _resetInactivityTimer();
-      } else {
-        _logger.info(
-          'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
-        );
-        _dispose();
-      }
-    });
-  }
-
-  void _dispose() async {
-    if (!_isIsolateSpawned) return;
-    _logger.info('Disposing isolate and models');
-    await _releaseModels();
-    _isIsolateSpawned = false;
-    _isolate.kill();
-    _receivePort.close();
-    _inactivityTimer?.cancel();
-  }
-
-  /// Analyzes the given image data by running the full pipeline for faces, using [_analyzeImageSync] in the isolate.
-  Future<MLResult?> _analyzeImageInSingleIsolate(
-    FileMLInstruction instruction,
-  ) async {
-    final String filePath = await getImagePathForML(instruction.enteFile);
-
-    final Stopwatch stopwatch = Stopwatch()..start();
-    late MLResult result;
-
-    try {
-      final resultJsonString = await _runInIsolate(
-        (
-          FaceMlOperation.analyzeImage,
-          {
-            "enteFileID": instruction.enteFile.uploadedFileID ?? -1,
-            "filePath": filePath,
-            "runFaces": instruction.shouldRunFaces,
-            "runClip": instruction.shouldRunClip,
-            "faceDetectionAddress":
-                FaceDetectionService.instance.sessionAddress,
-            "faceEmbeddingAddress":
-                FaceEmbeddingService.instance.sessionAddress,
-            "clipImageAddress": ClipImageEncoder.instance.sessionAddress,
-          }
-        ),
-      ) as String?;
-      if (resultJsonString == null) {
-        if (!_shouldPauseIndexingAndClustering) {
-          _logger.severe('Analyzing image in isolate is giving back null');
-        }
-        return null;
-      }
-      result = MLResult.fromJsonString(resultJsonString);
-    } catch (e, s) {
-      _logger.severe(
-        "Could not analyze image with ID ${instruction.enteFile.uploadedFileID} \n",
-        e,
-        s,
-      );
-      debugPrint(
-        "This image with ID ${instruction.enteFile.uploadedFileID} has name ${instruction.enteFile.displayName}.",
-      );
-      final resultBuilder =
-          MLResult.fromEnteFileID(instruction.enteFile.uploadedFileID!)
-            ..errorOccurred();
-      return resultBuilder;
-    }
-    stopwatch.stop();
-    _logger.info(
-      "Finished Analyze image with uploadedFileID ${instruction.enteFile.uploadedFileID}, in "
-      "${stopwatch.elapsedMilliseconds} ms (including time waiting for inference engine availability)",
-    );
-
-    return result;
-  }
-
-  Future<List<double>> runClipTextInIsolate(String query) async {
-    try {
-      final int clipAddress = ClipTextEncoder.instance.sessionAddress;
-      final String remotePath = ClipTextEncoder.instance.vocabRemotePath;
-      final String tokenizerVocabPath =
-          await RemoteAssetsService.instance.getAssetPath(remotePath);
-      final textEmbedding = await _runInIsolate(
-        (
-          FaceMlOperation.runClipText,
-          {
-            "text": query,
-            "address": clipAddress,
-            "vocabPath": tokenizerVocabPath,
-            "useEntePlugin": Platform.isAndroid,
-          }
-        ),
-      ) as List<double>;
-      return textEmbedding;
-    } catch (e, s) {
-      _logger.severe("Could not run clip text in isolate", e, s);
-      rethrow;
-    }
-  }
-
-  static Future<MLResult> _analyzeImageSync(Map args) async {
+  static Future<MLResult> analyzeImageSync(Map args) async {
     try {
       final int enteFileID = args["enteFileID"] as int;
       final String imagePath = args["filePath"] as String;
@@ -906,7 +664,7 @@
       _logStatus();
       return true;
     }
-    if (_shouldPauseIndexingAndClustering) {
+    if (shouldPauseIndexingAndClustering) {
       // This should ideally not be triggered, because one of the above should be triggered instead.
_logger.warning( "Cannot run $function because indexing and clustering is being paused", @@ -923,7 +681,7 @@ class MLService { isFaceIndexingEnabled: ${LocalSettings.instance.isFaceIndexingEnabled} canRunMLController: $_mlControllerStatus isIndexingOrClusteringRunning: $_isIndexingOrClusteringRunning - shouldPauseIndexingAndClustering: $_shouldPauseIndexingAndClustering + shouldPauseIndexingAndClustering: $shouldPauseIndexingAndClustering debugIndexingDisabled: $debugIndexingDisabled '''; _logger.info(status); diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 81100df3ed..ffe6a5904a 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -16,8 +16,8 @@ import "package:photos/models/file/file.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/services/collections_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart"; +import "package:photos/services/machine_learning/ml_isolate.dart"; import "package:photos/services/machine_learning/ml_result.dart"; -import "package:photos/services/machine_learning/ml_service.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart"; import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart'; @@ -295,7 +295,7 @@ class SemanticSearchService { if (cachedResult != null) { return cachedResult; } - final textEmbedding = await MLService.instance.runClipTextInIsolate(query); + final textEmbedding = await MLIsolate.instance.runClipText(query); _queryCache.put(query, textEmbedding); return textEmbedding; }
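Usage sketch (editor's addendum, not part of the patch): after this change, MLService and SemanticSearchService no longer manage a private isolate and instead go through the new MLIsolate facade. The Dart snippet below is a minimal sketch of the resulting call pattern. `demo`, `instruction`, and `query` are hypothetical stand-ins; `FileMLInstruction`, `MLResult`, and the three MLIsolate methods are taken from the patch itself. In the app, the platform-appropriate model setup is performed by MLService's `_ensureReadyForInference()` before any image is analyzed.

    import "package:photos/services/machine_learning/ml_isolate.dart";
    import "package:photos/services/machine_learning/ml_result.dart";
    import "package:photos/utils/ml_util.dart";

    Future<void> demo(FileMLInstruction instruction, String query) async {
      // Load the face-detection, face-embedding, and CLIP image models inside
      // the ML isolate (the isolate itself is spawned lazily on first use).
      await MLIsolate.instance.loadModels();

      // Run face + CLIP inference for one file off the main isolate. This
      // returns null when the MLController has paused indexing and clustering.
      final MLResult? result =
          await MLIsolate.instance.analyzeImage(instruction);

      // Compute a CLIP text embedding, as the semantic search query path does.
      final List<double> embedding =
          await MLIsolate.instance.runClipText(query);

      // Persisting `result` and `embedding` is elided here; in the app that is
      // handled by the indexing pipeline and the semantic search query cache.
    }

Because the isolate disposes itself after 120 seconds of inactivity and is respawned by the next call to `_runInIsolate`, callers never need to manage its lifecycle directly.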