[mob][photos] Simplify face clustering isolate service

This commit is contained in:
laurenspriem
2024-09-02 18:33:29 +02:00
parent e1288bfd61
commit b598729f28
2 changed files with 50 additions and 155 deletions

View File

@@ -3,6 +3,7 @@ import 'dart:typed_data' show Uint8List;
import "package:logging/logging.dart";
import "package:photos/models/ml/face/box.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
@@ -34,6 +35,9 @@ enum IsolateOperation {
/// [MLComputer]
testLogging,
/// [FaceClusteringService]
linearIncrementalClustering
}
/// WARNING: Only return primitives unless you know the method is only going
@@ -125,8 +129,15 @@ Future<dynamic> isolateFunction(
// throw Exception("XXX logging from isolate testing exception handling");
return true;
/// Cases for MLComputer ends here
/// Cases for MLComputer end here
///
/// Cases for FaceClusteringService start here
/// FaceClusteringService
case IsolateOperation.linearIncrementalClustering:
final ClusteringResult result = runLinearClustering(args);
return result;
/// Cases for FaceClusteringService end here
}
}

View File

@@ -1,20 +1,19 @@
import "dart:async";
import "dart:developer";
import "dart:isolate";
import "dart:typed_data" show Uint8List;
import "package:computer/computer.dart";
import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
import "package:flutter/foundation.dart" show kDebugMode;
import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/core/error-reporting/super_logging.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/base/id.dart";
import "package:photos/services/isolate_functions.dart";
import "package:photos/services/isolate_service.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
final String faceID;
@@ -61,26 +60,32 @@ class ClusteringResult {
}
}
class FaceClusteringService {
class FaceClusteringService extends SuperIsolate {
@override
Logger get logger => _logger;
final _logger = Logger("FaceLinearClustering");
final _computer = Computer.shared();
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(minutes: 2);
int _activeTasks = 0;
final _initLock = Lock();
late Isolate _isolate;
late ReceivePort _receivePort = ReceivePort();
late SendPort _mainSendPort;
bool isSpawned = false;
bool isRunning = false;
static const kRecommendedDistanceThreshold = 0.24;
static const kConservativeDistanceThreshold = 0.16;
@override
bool get isDartUiIsolate => false;
@override
String get isolateName => "FaceClusteringIsolate";
@override
Future<void> onDispose() async {
return;
}
@override
bool get shouldAutomaticDispose => true;
// singleton pattern
FaceClusteringService._privateConstructor();
@@ -89,124 +94,7 @@ class FaceClusteringService {
static final instance = FaceClusteringService._privateConstructor();
factory FaceClusteringService() => instance;
Future<void> _initIsolate() async {
return _initLock.synchronized(() async {
if (isSpawned) return;
_receivePort = ReceivePort();
try {
_isolate = await Isolate.spawn(
_isolateMain,
_receivePort.sendPort,
);
_mainSendPort = await _receivePort.first as SendPort;
isSpawned = true;
_resetInactivityTimer();
} catch (e) {
_logger.severe('Could not spawn isolate', e);
isSpawned = false;
}
});
}
Future<void> _ensureSpawnedIsolate() async {
if (!isSpawned) {
await _initIsolate();
}
}
/// The main execution function of the isolate.
static void _isolateMain(SendPort mainSendPort) async {
Logger.root.level = kDebugMode ? Level.ALL : Level.INFO;
// TODO:lau move to right isolate logging
Logger.root.onRecord.listen((LogRecord rec) {
debugPrint('[MLIsolate] ${rec.toPrettyString()}');
});
final receivePort = ReceivePort();
mainSendPort.send(receivePort.sendPort);
receivePort.listen((message) async {
final functionIndex = message[0] as int;
final function = ClusterOperation.values[functionIndex];
final args = message[1] as Map<String, dynamic>;
final sendPort = message[2] as SendPort;
try {
switch (function) {
case ClusterOperation.linearIncrementalClustering:
final ClusteringResult result = _runLinearClustering(args);
sendPort.send(result);
break;
}
} catch (e, stackTrace) {
sendPort
.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
}
});
}
/// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
Future<dynamic> _runInIsolate(
(ClusterOperation, Map<String, dynamic>) message,
) async {
await _ensureSpawnedIsolate();
_resetInactivityTimer();
final completer = Completer<dynamic>();
final answerPort = ReceivePort();
_activeTasks++;
_mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
answerPort.listen((receivedMessage) {
if (receivedMessage is Map && receivedMessage.containsKey('error')) {
// Handle the error
final errorMessage = receivedMessage['error'];
final errorStackTrace = receivedMessage['stackTrace'];
final exception = Exception(errorMessage);
final stackTrace = StackTrace.fromString(errorStackTrace);
_activeTasks--;
completer.completeError(exception, stackTrace);
} else {
_activeTasks--;
completer.complete(receivedMessage);
}
});
return completer.future;
}
/// Resets a timer that kills the isolate after a certain amount of inactivity.
///
/// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
void _resetInactivityTimer() {
_inactivityTimer?.cancel();
_inactivityTimer = Timer(_inactivityDuration, () {
if (_activeTasks > 0) {
_logger.info('Tasks are still running. Delaying isolate disposal.');
// Optionally, reschedule the timer to check again later.
_resetInactivityTimer();
} else {
_logger.info(
'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
);
_dispose();
}
});
}
/// Disposes the isolate worker.
void _dispose() {
if (!isSpawned) return;
isSpawned = false;
_isolate.kill();
_receivePort.close();
_inactivityTimer?.cancel();
}
/// Runs the clustering algorithm [_runLinearClustering] on the given [input], in an isolate.
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
///
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
Future<ClusteringResult?> predictLinearIsolate(
@@ -238,31 +126,27 @@ class FaceClusteringService {
final stopwatchClustering = Stopwatch()..start();
// final Map<String, int> faceIdToCluster =
// await _runLinearClusteringInComputer(input);
final ClusteringResult faceIdToCluster = await _runInIsolate(
(
ClusterOperation.linearIncrementalClustering,
{
'input': input,
'fileIDToCreationTime': fileIDToCreationTime,
'distanceThreshold': distanceThreshold,
'conservativeDistanceThreshold': conservativeDistanceThreshold,
'useDynamicThreshold': useDynamicThreshold,
'offset': offset,
'oldClusterSummaries': oldClusterSummaries,
}
),
);
final ClusteringResult faceIdToCluster =
await runInIsolate(IsolateOperation.linearIncrementalClustering, {
'input': input,
'fileIDToCreationTime': fileIDToCreationTime,
'distanceThreshold': distanceThreshold,
'conservativeDistanceThreshold': conservativeDistanceThreshold,
'useDynamicThreshold': useDynamicThreshold,
'offset': offset,
'oldClusterSummaries': oldClusterSummaries,
});
// return _runLinearClusteringInComputer(input);
_logger.info(
'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
);
isRunning = false;
return faceIdToCluster;
} catch (e, stackTrace) {
_logger.severe('Error while running clustering', e, stackTrace);
isRunning = false;
rethrow;
} finally {
isRunning = false;
}
}
@@ -308,7 +192,7 @@ class FaceClusteringService {
}
}
/// Runs the clustering algorithm [_runLinearClustering] on the given [input], in computer, without any dynamic thresholding
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding
Future<ClusteringResult> predictLinearComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
@@ -344,7 +228,7 @@ class FaceClusteringService {
.toSet();
final startTime = DateTime.now();
final faceIdToCluster = await _computer.compute(
_runLinearClustering,
runLinearClustering,
param: {
"input": clusteringInput,
"fileIDToCreationTime": fileIDToCreationTime,
@@ -415,7 +299,7 @@ class FaceClusteringService {
final _logger = Logger("FaceLinearClustering");
ClusteringResult _runLinearClustering(Map args) {
ClusteringResult runLinearClustering(Map args) {
// final input = args['input'] as Map<String, (int?, Uint8List)>;
final input = args['input'] as Set<FaceDbInfoForClustering>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;