From b598729f2802f99f601f337595068634cbd5a163 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Mon, 2 Sep 2024 18:33:29 +0200 Subject: [PATCH] [mob][photos] Simplify face clustering isolate service --- mobile/lib/services/isolate_functions.dart | 15 +- .../face_clustering_service.dart | 190 ++++-------------- 2 files changed, 50 insertions(+), 155 deletions(-) diff --git a/mobile/lib/services/isolate_functions.dart b/mobile/lib/services/isolate_functions.dart index 4224200857..308a5f7b1d 100644 --- a/mobile/lib/services/isolate_functions.dart +++ b/mobile/lib/services/isolate_functions.dart @@ -3,6 +3,7 @@ import 'dart:typed_data' show Uint8List; import "package:logging/logging.dart"; import "package:photos/models/ml/face/box.dart"; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/ml_model.dart"; import "package:photos/services/machine_learning/ml_result.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart"; @@ -34,6 +35,9 @@ enum IsolateOperation { /// [MLComputer] testLogging, + + /// [FaceClusteringService] + linearIncrementalClustering } /// WARNING: Only return primitives unless you know the method is only going @@ -125,8 +129,15 @@ Future isolateFunction( // throw Exception("XXX logging from isolate testing exception handling"); return true; - /// Cases for MLComputer ends here + /// Cases for MLComputer end here - /// + /// Cases for FaceClusteringService start here + + /// FaceClusteringService + case IsolateOperation.linearIncrementalClustering: + final ClusteringResult result = runLinearClustering(args); + return result; + + /// Cases for FaceClusteringService end here } } diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index fefe2910d2..c6dbad7446 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -1,20 +1,19 @@ import "dart:async"; import "dart:developer"; -import "dart:isolate"; import "dart:typed_data" show Uint8List; import "package:computer/computer.dart"; -import "package:flutter/foundation.dart" show debugPrint, kDebugMode; +import "package:flutter/foundation.dart" show kDebugMode; import "package:logging/logging.dart"; import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; -import "package:photos/core/error-reporting/super_logging.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/base/id.dart"; +import "package:photos/services/isolate_functions.dart"; +import "package:photos/services/isolate_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/ml_result.dart"; -import "package:synchronized/synchronized.dart"; class FaceInfo { final String faceID; @@ -61,26 +60,32 @@ class ClusteringResult { } } -class FaceClusteringService { +class FaceClusteringService extends SuperIsolate { + @override + Logger get logger => _logger; final _logger = Logger("FaceLinearClustering"); + final _computer = Computer.shared(); - Timer? _inactivityTimer; - final Duration _inactivityDuration = const Duration(minutes: 2); - int _activeTasks = 0; - - final _initLock = Lock(); - - late Isolate _isolate; - late ReceivePort _receivePort = ReceivePort(); - late SendPort _mainSendPort; - - bool isSpawned = false; bool isRunning = false; static const kRecommendedDistanceThreshold = 0.24; static const kConservativeDistanceThreshold = 0.16; + @override + bool get isDartUiIsolate => false; + + @override + String get isolateName => "FaceClusteringIsolate"; + + @override + Future onDispose() async { + return; + } + + @override + bool get shouldAutomaticDispose => true; + // singleton pattern FaceClusteringService._privateConstructor(); @@ -89,124 +94,7 @@ class FaceClusteringService { static final instance = FaceClusteringService._privateConstructor(); factory FaceClusteringService() => instance; - Future _initIsolate() async { - return _initLock.synchronized(() async { - if (isSpawned) return; - - _receivePort = ReceivePort(); - - try { - _isolate = await Isolate.spawn( - _isolateMain, - _receivePort.sendPort, - ); - _mainSendPort = await _receivePort.first as SendPort; - isSpawned = true; - - _resetInactivityTimer(); - } catch (e) { - _logger.severe('Could not spawn isolate', e); - isSpawned = false; - } - }); - } - - Future _ensureSpawnedIsolate() async { - if (!isSpawned) { - await _initIsolate(); - } - } - - /// The main execution function of the isolate. - static void _isolateMain(SendPort mainSendPort) async { - Logger.root.level = kDebugMode ? Level.ALL : Level.INFO; - // TODO:lau move to right isolate logging - Logger.root.onRecord.listen((LogRecord rec) { - debugPrint('[MLIsolate] ${rec.toPrettyString()}'); - }); - final receivePort = ReceivePort(); - mainSendPort.send(receivePort.sendPort); - - receivePort.listen((message) async { - final functionIndex = message[0] as int; - final function = ClusterOperation.values[functionIndex]; - final args = message[1] as Map; - final sendPort = message[2] as SendPort; - - try { - switch (function) { - case ClusterOperation.linearIncrementalClustering: - final ClusteringResult result = _runLinearClustering(args); - sendPort.send(result); - break; - } - } catch (e, stackTrace) { - sendPort - .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); - } - }); - } - - /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. - Future _runInIsolate( - (ClusterOperation, Map) message, - ) async { - await _ensureSpawnedIsolate(); - _resetInactivityTimer(); - final completer = Completer(); - final answerPort = ReceivePort(); - - _activeTasks++; - _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); - - answerPort.listen((receivedMessage) { - if (receivedMessage is Map && receivedMessage.containsKey('error')) { - // Handle the error - final errorMessage = receivedMessage['error']; - final errorStackTrace = receivedMessage['stackTrace']; - final exception = Exception(errorMessage); - final stackTrace = StackTrace.fromString(errorStackTrace); - _activeTasks--; - completer.completeError(exception, stackTrace); - } else { - _activeTasks--; - completer.complete(receivedMessage); - } - }); - - return completer.future; - } - - /// Resets a timer that kills the isolate after a certain amount of inactivity. - /// - /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`) - void _resetInactivityTimer() { - _inactivityTimer?.cancel(); - _inactivityTimer = Timer(_inactivityDuration, () { - if (_activeTasks > 0) { - _logger.info('Tasks are still running. Delaying isolate disposal.'); - // Optionally, reschedule the timer to check again later. - _resetInactivityTimer(); - } else { - _logger.info( - 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', - ); - _dispose(); - } - }); - } - - /// Disposes the isolate worker. - void _dispose() { - if (!isSpawned) return; - - isSpawned = false; - _isolate.kill(); - _receivePort.close(); - _inactivityTimer?.cancel(); - } - - /// Runs the clustering algorithm [_runLinearClustering] on the given [input], in an isolate. + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate. /// /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. Future predictLinearIsolate( @@ -238,31 +126,27 @@ class FaceClusteringService { final stopwatchClustering = Stopwatch()..start(); // final Map faceIdToCluster = // await _runLinearClusteringInComputer(input); - final ClusteringResult faceIdToCluster = await _runInIsolate( - ( - ClusterOperation.linearIncrementalClustering, - { - 'input': input, - 'fileIDToCreationTime': fileIDToCreationTime, - 'distanceThreshold': distanceThreshold, - 'conservativeDistanceThreshold': conservativeDistanceThreshold, - 'useDynamicThreshold': useDynamicThreshold, - 'offset': offset, - 'oldClusterSummaries': oldClusterSummaries, - } - ), - ); + final ClusteringResult faceIdToCluster = + await runInIsolate(IsolateOperation.linearIncrementalClustering, { + 'input': input, + 'fileIDToCreationTime': fileIDToCreationTime, + 'distanceThreshold': distanceThreshold, + 'conservativeDistanceThreshold': conservativeDistanceThreshold, + 'useDynamicThreshold': useDynamicThreshold, + 'offset': offset, + 'oldClusterSummaries': oldClusterSummaries, + }); // return _runLinearClusteringInComputer(input); _logger.info( 'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', ); - isRunning = false; return faceIdToCluster; } catch (e, stackTrace) { _logger.severe('Error while running clustering', e, stackTrace); - isRunning = false; rethrow; + } finally { + isRunning = false; } } @@ -308,7 +192,7 @@ class FaceClusteringService { } } - /// Runs the clustering algorithm [_runLinearClustering] on the given [input], in computer, without any dynamic thresholding + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding Future predictLinearComputer( Map input, { Map? fileIDToCreationTime, @@ -344,7 +228,7 @@ class FaceClusteringService { .toSet(); final startTime = DateTime.now(); final faceIdToCluster = await _computer.compute( - _runLinearClustering, + runLinearClustering, param: { "input": clusteringInput, "fileIDToCreationTime": fileIDToCreationTime, @@ -415,7 +299,7 @@ class FaceClusteringService { final _logger = Logger("FaceLinearClustering"); -ClusteringResult _runLinearClustering(Map args) { +ClusteringResult runLinearClustering(Map args) { // final input = args['input'] as Map; final input = args['input'] as Set; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?;