[mob][photos] Move ONNX model initialization in abstract class
This commit is contained in:
@@ -1,79 +1,40 @@
|
||||
import "dart:async";
|
||||
import "dart:developer" as dev show log;
|
||||
import "dart:io" show File;
|
||||
import 'dart:typed_data' show ByteData;
|
||||
import 'dart:ui' as ui show Image;
|
||||
|
||||
import "package:computer/computer.dart";
|
||||
import 'package:logging/logging.dart';
|
||||
import 'package:onnxruntime/onnxruntime.dart';
|
||||
import "package:photos/face/model/dimension.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
|
||||
class YOLOFaceInterpreterRunException implements Exception {}
|
||||
|
||||
/// This class is responsible for running the face detection model (YOLOv5Face) on ONNX runtime, and can be accessed through the singleton instance [FaceDetectionService.instance].
|
||||
class FaceDetectionService {
|
||||
class FaceDetectionService extends MlModel {
|
||||
static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx";
|
||||
|
||||
@override
|
||||
String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
|
||||
|
||||
@override
|
||||
Logger get logger => _logger;
|
||||
static final _logger = Logger('FaceDetectionService');
|
||||
|
||||
final _computer = Computer.shared();
|
||||
|
||||
int sessionAddress = 0;
|
||||
|
||||
static const String kModelBucketEndpoint = "https://models.ente.io/";
|
||||
static const String kRemoteBucketModelPath =
|
||||
"yolov5s_face_640_640_dynamic.onnx";
|
||||
static const String modelRemotePath =
|
||||
kModelBucketEndpoint + kRemoteBucketModelPath;
|
||||
|
||||
static const int kInputWidth = 640;
|
||||
static const int kInputHeight = 640;
|
||||
static const double kIouThreshold = 0.4;
|
||||
static const double kMinScoreSigmoidThreshold = 0.7;
|
||||
static const int kNumKeypoints = 5;
|
||||
|
||||
bool isInitialized = false;
|
||||
|
||||
// Singleton pattern
|
||||
FaceDetectionService._privateConstructor();
|
||||
static final instance = FaceDetectionService._privateConstructor();
|
||||
factory FaceDetectionService() => instance;
|
||||
|
||||
/// Check if the interpreter is initialized, if not initialize it with `loadModel()`
|
||||
Future<void> init() async {
|
||||
if (!isInitialized) {
|
||||
_logger.info('init is called');
|
||||
final model =
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
final startTime = DateTime.now();
|
||||
sessionAddress = await _computer.compute(
|
||||
_loadModel,
|
||||
param: {
|
||||
"modelPath": model.path,
|
||||
},
|
||||
);
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
|
||||
);
|
||||
if (sessionAddress != -1) {
|
||||
isInitialized = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> release() async {
|
||||
if (isInitialized) {
|
||||
await _computer
|
||||
.compute(_releaseModel, param: {'address': sessionAddress});
|
||||
isInitialized = false;
|
||||
sessionAddress = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Detects faces in the given image data.
|
||||
static Future<(List<FaceDetectionRelative>, Dimensions)> predict(
|
||||
ui.Image image,
|
||||
@@ -184,30 +145,4 @@ class FaceDetectionService {
|
||||
|
||||
return relativeDetections;
|
||||
}
|
||||
|
||||
/// Initialize the interpreter by loading the model file.
|
||||
static Future<int> _loadModel(Map args) async {
|
||||
final sessionOptions = OrtSessionOptions()
|
||||
..setInterOpNumThreads(1)
|
||||
..setIntraOpNumThreads(1)
|
||||
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
||||
try {
|
||||
final session =
|
||||
OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
|
||||
return session.address;
|
||||
} catch (e, s) {
|
||||
_logger.severe('Face detection model not loaded', e, s);
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
static Future<void> _releaseModel(Map args) async {
|
||||
final address = args['address'] as int;
|
||||
if (address == 0) {
|
||||
return;
|
||||
}
|
||||
final session = OrtSession.fromAddress(address);
|
||||
session.release();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,95 +1,33 @@
|
||||
import "dart:io" show File;
|
||||
import 'dart:math' as math show sqrt;
|
||||
import 'dart:typed_data' show Float32List;
|
||||
|
||||
import 'package:computer/computer.dart';
|
||||
import 'package:logging/logging.dart';
|
||||
import 'package:onnxruntime/onnxruntime.dart';
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
|
||||
class MobileFaceNetInterpreterRunException implements Exception {}
|
||||
|
||||
/// This class is responsible for running the face embedding model (MobileFaceNet) on ONNX runtime, and can be accessed through the singleton instance [FaceEmbeddingService.instance].
|
||||
class FaceEmbeddingService {
|
||||
static const kModelBucketEndpoint = "https://models.ente.io/";
|
||||
class FaceEmbeddingService extends MlModel {
|
||||
static const kRemoteBucketModelPath = "mobilefacenet_opset15.onnx";
|
||||
static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath;
|
||||
|
||||
@override
|
||||
String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
|
||||
|
||||
@override
|
||||
Logger get logger => _logger;
|
||||
static final _logger = Logger('FaceEmbeddingService');
|
||||
|
||||
static const int kInputSize = 112;
|
||||
static const int kEmbeddingSize = 192;
|
||||
static const int kNumChannels = 3;
|
||||
static const bool kPreWhiten = false;
|
||||
|
||||
static final _logger = Logger('FaceEmbeddingService');
|
||||
|
||||
bool isInitialized = false;
|
||||
int sessionAddress = 0;
|
||||
|
||||
final _computer = Computer.shared();
|
||||
|
||||
// Singleton pattern
|
||||
FaceEmbeddingService._privateConstructor();
|
||||
static final instance = FaceEmbeddingService._privateConstructor();
|
||||
factory FaceEmbeddingService() => instance;
|
||||
|
||||
/// Check if the interpreter is initialized, if not initialize it with `loadModel()`
|
||||
Future<void> init() async {
|
||||
if (!isInitialized) {
|
||||
_logger.info('init is called');
|
||||
final model =
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
final startTime = DateTime.now();
|
||||
// Doing this from main isolate since `rootBundle` cannot be accessed outside it
|
||||
sessionAddress = await _computer.compute(
|
||||
_loadModel,
|
||||
param: {
|
||||
"modelPath": model.path,
|
||||
},
|
||||
);
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Face embedding model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
|
||||
);
|
||||
if (sessionAddress != -1) {
|
||||
isInitialized = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> release() async {
|
||||
if (isInitialized) {
|
||||
await _computer
|
||||
.compute(_releaseModel, param: {'address': sessionAddress});
|
||||
isInitialized = false;
|
||||
sessionAddress = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static Future<int> _loadModel(Map args) async {
|
||||
final sessionOptions = OrtSessionOptions()
|
||||
..setInterOpNumThreads(1)
|
||||
..setIntraOpNumThreads(1)
|
||||
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
||||
try {
|
||||
final session =
|
||||
OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
|
||||
return session.address;
|
||||
} catch (e, s) {
|
||||
_logger.severe('Face embedding model not loaded', e, s);
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
static Future<void> _releaseModel(Map args) async {
|
||||
final address = args['address'] as int;
|
||||
if (address == 0) {
|
||||
return;
|
||||
}
|
||||
final session = OrtSession.fromAddress(address);
|
||||
session.release();
|
||||
return;
|
||||
}
|
||||
|
||||
static Future<List<List<double>>> predict(
|
||||
Float32List input,
|
||||
int sessionAddress,
|
||||
|
||||
76
mobile/lib/services/machine_learning/ml_model.dart
Normal file
76
mobile/lib/services/machine_learning/ml_model.dart
Normal file
@@ -0,0 +1,76 @@
|
||||
import "dart:io" show File;
|
||||
|
||||
import "package:computer/computer.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
|
||||
abstract class MlModel {
|
||||
Logger get logger;
|
||||
|
||||
String get kModelBucketEndpoint => "https://models.ente.io/";
|
||||
static const kRemoteBucketModelPath = "";
|
||||
|
||||
String get modelRemotePath;
|
||||
|
||||
bool isInitialized = false;
|
||||
int sessionAddress = 0;
|
||||
|
||||
final computer = Computer.shared();
|
||||
|
||||
Future<void> init() async {
|
||||
if (!isInitialized) {
|
||||
logger.info('init is called');
|
||||
final model =
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
final startTime = DateTime.now();
|
||||
try {
|
||||
sessionAddress = await computer.compute(
|
||||
_loadModel,
|
||||
param: {
|
||||
"modelPath": model.path,
|
||||
},
|
||||
);
|
||||
isInitialized = true;
|
||||
final endTime = DateTime.now();
|
||||
logger.info(
|
||||
"model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
|
||||
);
|
||||
} catch (e, s) {
|
||||
logger.severe('model not loaded', e, s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> release() async {
|
||||
if (isInitialized) {
|
||||
await computer.compute(_releaseModel, param: {'address': sessionAddress});
|
||||
isInitialized = false;
|
||||
sessionAddress = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static Future<int> _loadModel(Map args) async {
|
||||
final sessionOptions = OrtSessionOptions()
|
||||
..setInterOpNumThreads(1)
|
||||
..setIntraOpNumThreads(1)
|
||||
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
||||
try {
|
||||
final session =
|
||||
OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
|
||||
return session.address;
|
||||
} catch (e) {
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
static Future<void> _releaseModel(Map args) async {
|
||||
final address = args['address'] as int;
|
||||
if (address == 0) {
|
||||
return;
|
||||
}
|
||||
final session = OrtSession.fromAddress(address);
|
||||
session.release();
|
||||
return;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user