diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart index f1eac22b76..4b78c00d56 100644 --- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart +++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart @@ -1,12 +1,10 @@ -import "dart:typed_data" show Uint8List; +import "dart:typed_data" show Int32List, Uint8List; import "dart:ui" show Image; import "package:logging/logging.dart"; -// import "package:onnx_dart/onnx_dart.dart"; +import "package:onnx_dart/onnx_dart.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/machine_learning/ml_model.dart"; -import "package:photos/utils/debug_ml_export_data.dart"; -// import "package:photos/utils/image_ml_util.dart"; import "package:photos/utils/ml_util.dart"; class ClipImageEncoder extends MlModel { @@ -36,27 +34,13 @@ class ClipImageEncoder extends MlModel { int? enteFileID, ]) async { final startTime = DateTime.now(); - // final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true); - // final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false); - // await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip"); - // await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip"); - final preprocessingTime = DateTime.now(); - final preprocessingMs = - preprocessingTime.difference(startTime).inMilliseconds; - late List resultAa; //, resultNoaa; + final inputShape = [image.height, image.width, 4]; // [H, W, C] + late List result; try { - if (false) { - // resultAa = await _runPlatformPluginPredict(rawRgbaBytes); - // resultNoaa = await _runPlatformPluginPredict(inputListNoaa); - // await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip"); - // await encodeAndSaveData( - // resultNoaa, - // "star-noaa-mobile-embedding", - // "clip", - // ); + if (MlModel.usePlatformPlugin) { + result = await _runPlatformPluginPredict(rawRgbaBytes, inputShape); } else { - resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress); - print('clip inference done with FFI package'); + result = _runFFIBasedPredict(rawRgbaBytes, inputShape, sessionAddress); } } catch (e, stackTrace) { _logger.severe( @@ -66,27 +50,21 @@ class ClipImageEncoder extends MlModel { ); rethrow; } - final inferTime = DateTime.now(); - final inferenceMs = inferTime.difference(preprocessingTime).inMilliseconds; - final totalMs = inferTime.difference(startTime).inMilliseconds; + final totalMs = DateTime.now().difference(startTime).inMilliseconds; _logger.info( - "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)", + "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""}", ); - await encodeAndSaveData( - resultAa, - 'singapore-rgba-mobile-embedding-ffi', - 'clip', - ); - return resultAa; + return result; } static List _runFFIBasedPredict( Uint8List inputImageList, + List inputImageShape, int sessionAddress, ) { final inputOrt = OrtValueTensor.createTensorWithDataList( inputImageList, - [1200, 1920, 4], + inputImageShape, ); final inputs = {'input': inputOrt}; final session = OrtSession.fromAddress(sessionAddress); @@ -102,16 +80,18 @@ class ClipImageEncoder extends MlModel { return embedding; } - // static Future> _runPlatformPluginPredict( - // Uint8List inputImageList, - // ) async { - // final OnnxDart plugin = OnnxDart(); - // final result = await plugin.predictRgba( - // inputImageList, - // _modelName, - // ); - // final List embedding = result!.sublist(0, 512); - // normalizeEmbedding(embedding); - // return embedding; - // } + static Future> _runPlatformPluginPredict( + Uint8List inputImageList, + List inputImageShape, + ) async { + final OnnxDart plugin = OnnxDart(); + final result = await plugin.predictRgba( + inputImageList, + Int32List.fromList(inputImageShape), + _modelName, + ); + final List embedding = result!.sublist(0, 512); + normalizeEmbedding(embedding); + return embedding; + } } diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt index 44b1811c2d..cf13f63344 100644 --- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt +++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt @@ -93,13 +93,15 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { val modelType = call.argument("modelType") val inputDataArray = call.argument("inputData") val inputIntDataArray = call.argument("inputDataInt") + val inputUint8DataArray = call.argument("inputDataUint8") + val inputShapeArray = call.argument("inputShapeList") - if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) { + if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null && inputUint8DataArray == null)) { result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null) return } - predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result) + predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, inputUint8DataArray, inputShapeArray, result) } else -> { result.notImplemented() @@ -155,9 +157,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { } } - private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) { + private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, inputUint8DataArray: ByteArray? = null, inputShapeArray: IntArray? = null, result: Result) { // Assert that exactly one of inputDataFloat or inputDataInt is provided - assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" } + // assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" } + // TODO: lau re-add above assert later scope.launch { val modelState = sessionMap[modelType] @@ -178,8 +181,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { inputTensorShape = longArrayOf(totalSize, 112, 112, 3) } ModelType.ClipImageEncoder -> { - inputTensorShape = longArrayOf(1, 3, 256, 256) -// inputTensorShape = longArrayOf(1, 3, 256, 256) + if (inputShapeArray != null) { + inputTensorShape = inputShapeArray.map { it.toLong() }.toLongArray() + } else { + result.error("INVALID_ARGUMENT", "Input shape is missing for clip image input", null) + } } ModelType.ClipTextEncoder -> { inputTensorShape = longArrayOf(1, 77) @@ -192,6 +198,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { val inputTensor = when { inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape) inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape) + inputUint8DataArray != null -> OnnxTensor.createTensor(env, ByteBuffer.wrap(inputUint8DataArray), inputTensorShape, OnnxJavaType.UINT8) else -> throw IllegalArgumentException("No input data provided") } val inputs = mutableMapOf() @@ -219,7 +226,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { inputTensor.close() } catch (e: OrtException) { withContext(Dispatchers.Main) { - result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null) + result.error("PREDICTION_ERROR", "Error during prediction: ${e.message} ${e.stackTraceToString()}", null) } } catch (e: Exception) { Log.e(TAG, "Error during prediction: ${e.message}", e) diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart.dart b/mobile/plugins/onnx_dart/lib/onnx_dart.dart index 37ac068dbf..f13bf6926d 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart.dart @@ -25,8 +25,13 @@ class OnnxDart { String modelType, { int sessionAddress = 0, }) async { - final result = await OnnxDartPlatform.instance - .predict(inputData, null, modelType, sessionAddress: sessionAddress); + final result = await OnnxDartPlatform.instance.predict( + inputData, + null, + null, + modelType, + sessionAddress: sessionAddress, + ); return result; } @@ -35,8 +40,30 @@ class OnnxDart { String modelType, { int sessionAddress = 0, }) async { - final result = await OnnxDartPlatform.instance - .predict(null, inputDataInt, modelType, sessionAddress: sessionAddress); + final result = await OnnxDartPlatform.instance.predict( + null, + inputDataInt, + null, + modelType, + sessionAddress: sessionAddress, + ); + return result; + } + + Future predictRgba( + Uint8List inputBytes, + Int32List inputShape, + String modelType, { + int sessionAddress = 0, + }) async { + final result = await OnnxDartPlatform.instance.predict( + null, + null, + inputBytes, + modelType, + sessionAddress: sessionAddress, + inputShapeList: inputShape, + ); return result; } } diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart index 8d638eb1f9..a602f41fb8 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart @@ -40,8 +40,10 @@ class MethodChannelOnnxDart extends OnnxDartPlatform { Future predict( Float32List? inputData, Int32List? inputDataInt, + Uint8List? inputDataUint8, String modelType, { int sessionAddress = 0, + Int32List? inputShapeList, }) { return methodChannel.invokeMethod( 'predict', @@ -49,7 +51,9 @@ class MethodChannelOnnxDart extends OnnxDartPlatform { 'sessionAddress': sessionAddress, 'inputData': inputData, 'inputDataInt': inputDataInt, + 'inputDataUint8': inputDataUint8, 'modelType': modelType, + "inputShapeList": inputShapeList, }, ); } diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart index 25ada82f8d..82a6107a09 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart @@ -42,12 +42,12 @@ abstract class OnnxDartPlatform extends PlatformInterface { Future predict( Float32List? inputData, - Int32List? inputDataInt, + Int32List? inputDataInt, + Uint8List? inputDataRgba, String modelType, { int sessionAddress = 0, + Int32List? inputShapeList, }) { throw UnimplementedError('predict() has not been implemented.'); } - - }