From 227877bd07e27930da8596200289dfa3f61d9c4f Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:43:31 +0530 Subject: [PATCH] [mob] Add support for Clip text in custom plugin --- .../ente/photos/onnx_dart/OnnxDartPlugin.kt | 106 ++++++++++-------- mobile/plugins/onnx_dart/lib/onnx_dart.dart | 12 +- .../lib/onnx_dart_method_channel.dart | 4 +- .../lib/onnx_dart_platform_interface.dart | 5 +- 4 files changed, 80 insertions(+), 47 deletions(-) diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt index a2dfbae4fb..eaf32c21e1 100644 --- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt +++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt @@ -1,6 +1,7 @@ package io.ente.photos.onnx_dart import android.content.Context +import java.nio.IntBuffer import androidx.annotation.NonNull import ai.onnxruntime.* import java.util.EnumMap @@ -35,7 +36,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { private lateinit var context: Context enum class ModelType { - CLIP_TEXT, ClipImageEncoder, YOLOv5Face, MobileFaceNet + ClipTextEncoder, ClipImageEncoder, YOLOv5Face, MobileFaceNet } companion object { const val DEFAULT_SESSION_COUNT = 1 @@ -63,41 +64,46 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { override fun onMethodCall(call: MethodCall, result: Result) { - if (call.method == "getPlatformVersion") { - result.success("Android ${android.os.Build.VERSION.RELEASE}") - } else if (call.method == "init") { - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return + when (call.method) { + "getPlatformVersion" -> { + result.success("Android ${android.os.Build.VERSION.RELEASE}") } - val modelPath = call.argument("modelPath") ?: run { - result.error("INVALID_ARGUMENT", "Model path is missing", null) - return + "init" -> { + val modelType = call.argument("modelType") + val modelPath = call.argument("modelPath") + val sessionsCount = call.argument("sessionsCount") ?: DEFAULT_SESSION_COUNT + + if (modelType == null || modelPath == null) { + result.error("INVALID_ARGUMENT", "Model type or path is missing", null) + return + } + + init(ModelType.valueOf(modelType), modelPath, sessionsCount, result) } - val sessionsCount = call.argument("sessionsCount") ?: DEFAULT_SESSION_COUNT - init(ModelType.valueOf(modelType), modelPath, sessionsCount, result) - } - else if (call.method == "release" ) { - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return + "release" -> { + val modelType = call.argument("modelType") + if (modelType == null) { + result.error("INVALID_ARGUMENT", "Model type is missing", null) + return + } + release(ModelType.valueOf(modelType), result) } - release(ModelType.valueOf(modelType), result) - } - else if (call.method == "predict" ) { - val sessionAddress = call.argument("sessionAddress") - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return + "predict" -> { + val sessionAddress = call.argument("sessionAddress") + val modelType = call.argument("modelType") + val inputDataArray = call.argument("inputData") + val inputIntDataArray = call.argument("inputDataInt") + + if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) { + result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null) + return + } + + predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result) } - val inputDataArray = call.argument("inputData") - if (sessionAddress == null || inputDataArray == null) { - result.error("INVALID_ARGUMENT", "Session address or input data is missing", null) - return + else -> { + result.notImplemented() } - predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result) - } else { - result.notImplemented() } } @@ -149,7 +155,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { } } - private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) { + private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) { + // Assert that exactly one of inputDataFloat or inputDataInt is provided + assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" } + scope.launch { val modelState = sessionMap[modelType] val session = modelState?.sessionAddresses?.get(sessionAddress) @@ -162,19 +171,28 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { try { val env = OrtEnvironment.getEnvironment() - val inputTensorShape = longArrayOf(1, 3, 640, 640) - if (modelType == ModelType.MobileFaceNet) { - inputTensorShape[0] = inputData.size.toLong()/ FACENET_SINGLE_INPUT_SIZE - inputTensorShape[1] = 112 - inputTensorShape[2] = 112 - inputTensorShape[3] = 3 - } else if(modelType == ModelType.ClipImageEncoder) { - inputTensorShape[0] = 1 - inputTensorShape[1] = 3 - inputTensorShape[2] = 224 - inputTensorShape[3] = 224 + var inputTensorShape = longArrayOf(1, 3, 640, 640) + when (modelType) { + ModelType.MobileFaceNet -> { + val totalSize = inputDataFloat!!.size.toLong() / FACENET_SINGLE_INPUT_SIZE + inputTensorShape = longArrayOf(totalSize, 112, 112, 3) + } + ModelType.ClipImageEncoder -> { + inputTensorShape = longArrayOf(1, 3, 224, 224) + } + ModelType.ClipTextEncoder -> { + inputTensorShape = longArrayOf(1, 77) + } + ModelType.YOLOv5Face -> { + inputTensorShape = longArrayOf(1, 3, 640, 640) + } + } + + val inputTensor = when { + inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape) + inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape) + else -> throw IllegalArgumentException("No input data provided") } - val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape) val inputs = mutableMapOf() if (modelType == ModelType.MobileFaceNet) { inputs["img_inputs"] = inputTensor diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart.dart b/mobile/plugins/onnx_dart/lib/onnx_dart.dart index f21ebae21a..e04462958b 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart.dart @@ -22,7 +22,17 @@ class OnnxDart { int sessionAddress = 0, }) async { final result = await OnnxDartPlatform.instance - .predict(inputData, modelType, sessionAddress: sessionAddress); + .predict(inputData, null, modelType, sessionAddress: sessionAddress); + return result; + } + + Future predictInt( + Int32List inputDataInt, + String modelType, { + int sessionAddress = 0, + }) async { + final result = await OnnxDartPlatform.instance + .predict(null, inputDataInt, modelType, sessionAddress: sessionAddress); return result; } } diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart index 669ceca1eb..8d638eb1f9 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart @@ -38,7 +38,8 @@ class MethodChannelOnnxDart extends OnnxDartPlatform { @override Future predict( - Float32List inputData, + Float32List? inputData, + Int32List? inputDataInt, String modelType, { int sessionAddress = 0, }) { @@ -47,6 +48,7 @@ class MethodChannelOnnxDart extends OnnxDartPlatform { { 'sessionAddress': sessionAddress, 'inputData': inputData, + 'inputDataInt': inputDataInt, 'modelType': modelType, }, ); diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart index 626f5236a1..25ada82f8d 100644 --- a/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart @@ -41,10 +41,13 @@ abstract class OnnxDartPlatform extends PlatformInterface { } Future predict( - Float32List inputData, + Float32List? inputData, + Int32List? inputDataInt, String modelType, { int sessionAddress = 0, }) { throw UnimplementedError('predict() has not been implemented.'); } + + }