From 08f846c315eb1e0b2dd1b034f18488f6f151d86a Mon Sep 17 00:00:00 2001
From: Neeraj Gupta <254676+ua741@users.noreply.github.com>
Date: Wed, 10 Jul 2024 18:22:04 +0530
Subject: [PATCH] [mob] Use custom plugin for clip image encoding

---
 .../services/machine_learning/ml_service.dart |  3 +-
 .../clip/clip_image_encoder.dart              | 44 ++++++++++++++++---
 .../semantic_search_service.dart              | 13 ++++--
 .../ente/photos/onnx_dart/OnnxDartPlugin.kt   | 11 +++--
 4 files changed, 58 insertions(+), 13 deletions(-)

diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart
index eca77c95ee..469fa289be 100644
--- a/mobile/lib/services/machine_learning/ml_service.dart
+++ b/mobile/lib/services/machine_learning/ml_service.dart
@@ -1,6 +1,6 @@
 import "dart:async";
 import "dart:developer" as dev show log;
-import "dart:io" show File;
+import "dart:io" show File, Platform;
 import "dart:isolate";
 import "dart:math" show min;
 import "dart:typed_data" show Uint8List, ByteData;
@@ -838,6 +838,7 @@ class MLService {
           image,
           imageByteData,
           clipImageAddress,
+          useEntePlugin: Platform.isAndroid,
         );
         result.clip = clipResult;
       }
diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
index 3bacee5fd5..ba68ed1dc7 100644
--- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
@@ -1,14 +1,17 @@
-import "dart:typed_data" show ByteData;
+import "dart:typed_data";
 import "dart:ui" show Image;
 
 import "package:logging/logging.dart";
+import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
+import "package:photos/extensions/stop_watch.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";
 
 class ClipImageEncoder extends MlModel {
   static const kRemoteBucketModelPath = "clip-image-vit-32-float32.onnx";
+  static const _modelName = "ClipImageEncoder";
 
   @override
   String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
@@ -18,7 +21,7 @@ class ClipImageEncoder extends MlModel {
   static final _logger = Logger('ClipImageEncoder');
 
   @override
-  String get modelName => "ClipImageEncoder";
+  String get modelName => _modelName;
 
   // Singleton pattern
   ClipImageEncoder._privateConstructor();
@@ -28,10 +31,27 @@ class ClipImageEncoder extends MlModel {
   static Future<List<double>> predict(
     Image image,
     ByteData imageByteData,
-    int sessionAddress,
-  ) async {
+    int sessionAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final w = EnteWatch("ClipImageEncoder.predict")..start();
     final inputList = await preprocessImageClip(image, imageByteData);
+    w.log("preprocessImageClip");
+    if (useEntePlugin) {
+      final result = await _runEntePlugin(inputList);
+      w.stopWithLog("done");
+      return result;
+    }
+    final result = _runFFIBasedPredict(inputList, sessionAddress);
+    w.stopWithLog("done");
+    return result;
+  }
+
+  static List<double> _runFFIBasedPredict(
+    Float32List inputList,
+    int sessionAddress,
+  ) {
+    final w = EnteWatch("ClipImageEncoder._runFFIBasedPredict")..start();
     final inputOrt =
         OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
     final inputs = {'input': inputOrt};
     final runOptions = OrtRunOptions();
     final outputs = session.run(runOptions, inputs);
     final embedding = (outputs[0]?.value as List<List<double>>)[0];
-    normalizeEmbedding(embedding);
+    w.stopWithLog("done");
+    return embedding;
+  }
+  static Future<List<double>> _runEntePlugin(
+    Float32List inputImageList,
+  ) async {
+    final w = EnteWatch("ClipImageEncoder._runEntePlugin")..start();
+    final OnnxDart plugin = OnnxDart();
+    final result = await plugin.predict(
+      inputImageList,
+      _modelName,
+    );
+    final List<double> embedding = result!.sublist(0, 512);
+    normalizeEmbedding(embedding);
+    w.stopWithLog("done");
     return embedding;
   }
 }
diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
index 740f84f342..a03fe19a5a 100644
--- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
@@ -334,10 +334,15 @@ class SemanticSearchService {
     int enteFileID,
     Image image,
     ByteData imageByteData,
-    int clipImageAddress,
-  ) async {
-    final embedding =
-        await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
+    int clipImageAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final embedding = await ClipImageEncoder.predict(
+      image,
+      imageByteData,
+      clipImageAddress,
+      useEntePlugin: useEntePlugin,
+    );
     final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);
 
     return clipResult;
diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
index 2caa2c4236..a2dfbae4fb 100644
--- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
+++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
@@ -168,6 +168,11 @@
             inputTensorShape[1] = 112
             inputTensorShape[2] = 112
             inputTensorShape[3] = 3
+        } else if(modelType == ModelType.ClipImageEncoder) {
+            inputTensorShape[0] = 1
+            inputTensorShape[1] = 3
+            inputTensorShape[2] = 224
+            inputTensorShape[3] = 224
         }
         val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
         val inputs = mutableMapOf<String, OnnxTensor>()
@@ -178,14 +183,14 @@
         }
         val outputs = session.run(inputs)
         Log.d(TAG, "Output shape: ${outputs.size()}")
-        if (modelType == ModelType.MobileFaceNet) {
-            val outputTensor = (outputs[0].value as Array<FloatArray>)
+        if (modelType == ModelType.YOLOv5Face) {
+            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
             }
         } else {
-            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
+            val outputTensor = (outputs[0].value as Array<FloatArray>)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList
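
Usage sketch (editorial note, not part of the patch): the ml_service.dart hunk gates the new path with useEntePlugin: Platform.isAndroid, so CLIP inference goes through the onnx_dart platform plugin only on Android and falls back to the FFI-based onnxruntime session elsewhere. The Dart fragment below is a minimal illustration of that call site under stated assumptions: runClipOnImage is a hypothetical wrapper, and sessionAddress is assumed to come from the app's existing FFI session management.

import "dart:io" show Platform;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;

import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";

// Hypothetical wrapper mirroring the ml_service.dart change above: on Android,
// ClipImageEncoder.predict routes through the onnx_dart plugin (_runEntePlugin);
// on other platforms it keeps using the FFI-based onnxruntime session addressed
// by sessionAddress (_runFFIBasedPredict).
Future<List<double>> runClipOnImage(
  Image image,
  ByteData imageByteData,
  int sessionAddress,
) {
  return ClipImageEncoder.predict(
    image,
    imageByteData,
    sessionAddress,
    useEntePlugin: Platform.isAndroid,
  );
}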