diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
index 455f4e9c8d..7619bc9015 100644
--- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
@@ -1,79 +1,40 @@
 import "dart:async";
 import "dart:developer" as dev show log;
-import "dart:io" show File;
 import 'dart:typed_data' show ByteData;
 import 'dart:ui' as ui show Image;
 
-import "package:computer/computer.dart";
 import 'package:logging/logging.dart';
 import 'package:onnxruntime/onnxruntime.dart';
 import "package:photos/face/model/dimension.dart";
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
 import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart";
-import "package:photos/services/remote_assets_service.dart";
+import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/utils/image_ml_util.dart";
 
 class YOLOFaceInterpreterRunException implements Exception {}
 
 /// This class is responsible for running the face detection model (YOLOv5Face) on ONNX runtime, and can be accessed through the singleton instance [FaceDetectionService.instance].
-class FaceDetectionService {
+class FaceDetectionService extends MlModel {
+  static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx";
+
+  @override
+  String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
+
+  @override
+  Logger get logger => _logger;
   static final _logger = Logger('FaceDetectionService');
-  final _computer = Computer.shared();
-
-  int sessionAddress = 0;
-
-  static const String kModelBucketEndpoint = "https://models.ente.io/";
-  static const String kRemoteBucketModelPath =
-      "yolov5s_face_640_640_dynamic.onnx";
-  static const String modelRemotePath =
-      kModelBucketEndpoint + kRemoteBucketModelPath;
-
   static const int kInputWidth = 640;
   static const int kInputHeight = 640;
   static const double kIouThreshold = 0.4;
   static const double kMinScoreSigmoidThreshold = 0.7;
   static const int kNumKeypoints = 5;
 
-  bool isInitialized = false;
-
   // Singleton pattern
   FaceDetectionService._privateConstructor();
 
   static final instance = FaceDetectionService._privateConstructor();
   factory FaceDetectionService() => instance;
 
-  /// Check if the interpreter is initialized, if not initialize it with `loadModel()`
-  Future<void> init() async {
-    if (!isInitialized) {
-      _logger.info('init is called');
-      final model =
-          await RemoteAssetsService.instance.getAsset(modelRemotePath);
-      final startTime = DateTime.now();
-      sessionAddress = await _computer.compute(
-        _loadModel,
-        param: {
-          "modelPath": model.path,
-        },
-      );
-      final endTime = DateTime.now();
-      _logger.info(
-        "Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
-      );
-      if (sessionAddress != -1) {
-        isInitialized = true;
-      }
-    }
-  }
-
-  Future<void> release() async {
-    if (isInitialized) {
-      await _computer
-          .compute(_releaseModel, param: {'address': sessionAddress});
-      isInitialized = false;
-      sessionAddress = 0;
-    }
-  }
-
   /// Detects faces in the given image data.
   static Future<(List<FaceDetectionRelative>, Dimensions)> predict(
     ui.Image image,
@@ -184,30 +145,4 @@ class FaceDetectionService {
 
     return relativeDetections;
   }
-
-  /// Initialize the interpreter by loading the model file.
-  static Future<int> _loadModel(Map args) async {
-    final sessionOptions = OrtSessionOptions()
-      ..setInterOpNumThreads(1)
-      ..setIntraOpNumThreads(1)
-      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
-    try {
-      final session =
-          OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
-      return session.address;
-    } catch (e, s) {
-      _logger.severe('Face detection model not loaded', e, s);
-    }
-    return -1;
-  }
-
-  static Future<void> _releaseModel(Map args) async {
-    final address = args['address'] as int;
-    if (address == 0) {
-      return;
-    }
-    final session = OrtSession.fromAddress(address);
-    session.release();
-    return;
-  }
 }
diff --git a/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart
index 0e49cb8c4f..e4536dfeaa 100644
--- a/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart
+++ b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart
@@ -1,95 +1,33 @@
-import "dart:io" show File;
 import 'dart:math' as math show sqrt;
 import 'dart:typed_data' show Float32List;
 
-import 'package:computer/computer.dart';
 import 'package:logging/logging.dart';
 import 'package:onnxruntime/onnxruntime.dart';
-import "package:photos/services/remote_assets_service.dart";
+import "package:photos/services/machine_learning/ml_model.dart";
 
 class MobileFaceNetInterpreterRunException implements Exception {}
 
 /// This class is responsible for running the face embedding model (MobileFaceNet) on ONNX runtime, and can be accessed through the singleton instance [FaceEmbeddingService.instance].
-class FaceEmbeddingService {
-  static const kModelBucketEndpoint = "https://models.ente.io/";
+class FaceEmbeddingService extends MlModel {
   static const kRemoteBucketModelPath = "mobilefacenet_opset15.onnx";
-  static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath;
+
+  @override
+  String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
+
+  @override
+  Logger get logger => _logger;
+  static final _logger = Logger('FaceEmbeddingService');
   static const int kInputSize = 112;
   static const int kEmbeddingSize = 192;
   static const int kNumChannels = 3;
   static const bool kPreWhiten = false;
 
-  static final _logger = Logger('FaceEmbeddingService');
-
-  bool isInitialized = false;
-  int sessionAddress = 0;
-
-  final _computer = Computer.shared();
-
   // Singleton pattern
   FaceEmbeddingService._privateConstructor();
 
   static final instance = FaceEmbeddingService._privateConstructor();
   factory FaceEmbeddingService() => instance;
 
-  /// Check if the interpreter is initialized, if not initialize it with `loadModel()`
-  Future<void> init() async {
-    if (!isInitialized) {
-      _logger.info('init is called');
-      final model =
-          await RemoteAssetsService.instance.getAsset(modelRemotePath);
-      final startTime = DateTime.now();
-      // Doing this from main isolate since `rootBundle` cannot be accessed outside it
-      sessionAddress = await _computer.compute(
-        _loadModel,
-        param: {
-          "modelPath": model.path,
-        },
-      );
-      final endTime = DateTime.now();
-      _logger.info(
-        "Face embedding model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
-      );
-      if (sessionAddress != -1) {
-        isInitialized = true;
-      }
-    }
-  }
-
-  Future<void> release() async {
-    if (isInitialized) {
-      await _computer
-          .compute(_releaseModel, param: {'address': sessionAddress});
-      isInitialized = false;
-      sessionAddress = 0;
-    }
-  }
-
-  static Future<int> _loadModel(Map args) async {
-    final sessionOptions = OrtSessionOptions()
-      ..setInterOpNumThreads(1)
-      ..setIntraOpNumThreads(1)
-      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
-    try {
-      final session =
-          OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
-      return session.address;
-    } catch (e, s) {
-      _logger.severe('Face embedding model not loaded', e, s);
-    }
-    return -1;
-  }
-
-  static Future<void> _releaseModel(Map args) async {
-    final address = args['address'] as int;
-    if (address == 0) {
-      return;
-    }
-    final session = OrtSession.fromAddress(address);
-    session.release();
-    return;
-  }
-
   static Future<List<List<double>>> predict(
     Float32List input,
     int sessionAddress,
diff --git a/mobile/lib/services/machine_learning/ml_model.dart b/mobile/lib/services/machine_learning/ml_model.dart
new file mode 100644
index 0000000000..6565b67edb
--- /dev/null
+++ b/mobile/lib/services/machine_learning/ml_model.dart
@@ -0,0 +1,76 @@
+import "dart:io" show File;
+
+import "package:computer/computer.dart";
+import "package:logging/logging.dart";
+import "package:onnxruntime/onnxruntime.dart";
+import "package:photos/services/remote_assets_service.dart";
+
+abstract class MlModel {
+  Logger get logger;
+
+  String get kModelBucketEndpoint => "https://models.ente.io/";
+  static const kRemoteBucketModelPath = "";
+
+  String get modelRemotePath;
+
+  bool isInitialized = false;
+  int sessionAddress = 0;
+
+  final computer = Computer.shared();
+
+  Future<void> init() async {
+    if (!isInitialized) {
+      logger.info('init is called');
+      final model =
+          await RemoteAssetsService.instance.getAsset(modelRemotePath);
+      final startTime = DateTime.now();
+      try {
+        sessionAddress = await computer.compute(
+          _loadModel,
+          param: {
+            "modelPath": model.path,
+          },
+        );
+        isInitialized = true;
+        final endTime = DateTime.now();
+        logger.info(
+          "model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
+        );
+      } catch (e, s) {
+        logger.severe('model not loaded', e, s);
+      }
+    }
+  }
+
+  Future<void> release() async {
+    if (isInitialized) {
+      await computer.compute(_releaseModel, param: {'address': sessionAddress});
+      isInitialized = false;
+      sessionAddress = 0;
+    }
+  }
+
+  static Future<int> _loadModel(Map args) async {
+    final sessionOptions = OrtSessionOptions()
+      ..setInterOpNumThreads(1)
+      ..setIntraOpNumThreads(1)
+      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
+    try {
+      final session =
+          OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
+      return session.address;
+    } catch (e) {
+      rethrow;
+    }
+  }
+
+  static Future<void> _releaseModel(Map args) async {
+    final address = args['address'] as int;
+    if (address == 0) {
+      return;
+    }
+    final session = OrtSession.fromAddress(address);
+    session.release();
+    return;
+  }
+}
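
Usage sketch (not part of the diff): a hypothetical subclass showing how another ONNX-backed service would plug into the shared MlModel base class introduced above. The name ExampleModelService and the file name example_model.onnx are placeholders; only members visible in this diff (kModelBucketEndpoint, modelRemotePath, logger, init, release, sessionAddress) are assumed.

// Hypothetical example: a new ONNX-backed service only needs to point
// MlModel at its remote model file and provide a logger.
import "package:logging/logging.dart";
import "package:photos/services/machine_learning/ml_model.dart";

class ExampleModelService extends MlModel {
  // Placeholder model file name, used only to illustrate the pattern.
  static const kRemoteBucketModelPath = "example_model.onnx";

  @override
  String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;

  @override
  Logger get logger => _logger;
  static final _logger = Logger('ExampleModelService');

  // Singleton pattern, mirroring the services refactored above.
  ExampleModelService._privateConstructor();
  static final instance = ExampleModelService._privateConstructor();
  factory ExampleModelService() => instance;
}

// Callers inherit init(), release() and sessionAddress from MlModel:
//   await ExampleModelService.instance.init();
//   final address = ExampleModelService.instance.sessionAddress;
//   // ... run inference against the ONNX session at `address` ...
//   await ExampleModelService.instance.release();

Because MlModel exposes the loaded session as a plain int address, the existing static predict(...) methods in both refactored services keep their int sessionAddress parameter unchanged, and the session can still be rehydrated where needed via OrtSession.fromAddress, as the base class itself does in _releaseModel.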