[mob] Add support for Clip text in custom plugin

This commit is contained in:
Neeraj Gupta
2024-07-18 16:43:31 +05:30
parent 578c8a032f
commit 227877bd07
4 changed files with 80 additions and 47 deletions

View File

@@ -1,6 +1,7 @@
package io.ente.photos.onnx_dart
import android.content.Context
import java.nio.IntBuffer
import androidx.annotation.NonNull
import ai.onnxruntime.*
import java.util.EnumMap
@@ -35,7 +36,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
private lateinit var context: Context
enum class ModelType {
CLIP_TEXT, ClipImageEncoder, YOLOv5Face, MobileFaceNet
ClipTextEncoder, ClipImageEncoder, YOLOv5Face, MobileFaceNet
}
companion object {
const val DEFAULT_SESSION_COUNT = 1
@@ -63,41 +64,46 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
override fun onMethodCall(call: MethodCall, result: Result) {
if (call.method == "getPlatformVersion") {
result.success("Android ${android.os.Build.VERSION.RELEASE}")
} else if (call.method == "init") {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
when (call.method) {
"getPlatformVersion" -> {
result.success("Android ${android.os.Build.VERSION.RELEASE}")
}
val modelPath = call.argument<String>("modelPath") ?: run {
result.error("INVALID_ARGUMENT", "Model path is missing", null)
return
"init" -> {
val modelType = call.argument<String>("modelType")
val modelPath = call.argument<String>("modelPath")
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
if (modelType == null || modelPath == null) {
result.error("INVALID_ARGUMENT", "Model type or path is missing", null)
return
}
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
}
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
}
else if (call.method == "release" ) {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
"release" -> {
val modelType = call.argument<String>("modelType")
if (modelType == null) {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
release(ModelType.valueOf(modelType), result)
}
release(ModelType.valueOf(modelType), result)
}
else if (call.method == "predict" ) {
val sessionAddress = call.argument<Int>("sessionAddress")
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
"predict" -> {
val sessionAddress = call.argument<Int>("sessionAddress")
val modelType = call.argument<String>("modelType")
val inputDataArray = call.argument<FloatArray>("inputData")
val inputIntDataArray = call.argument<IntArray>("inputDataInt")
if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) {
result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null)
return
}
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result)
}
val inputDataArray = call.argument<FloatArray>("inputData")
if (sessionAddress == null || inputDataArray == null) {
result.error("INVALID_ARGUMENT", "Session address or input data is missing", null)
return
else -> {
result.notImplemented()
}
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result)
} else {
result.notImplemented()
}
}
@@ -149,7 +155,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
}
}
private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) {
private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) {
// Assert that exactly one of inputDataFloat or inputDataInt is provided
assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
scope.launch {
val modelState = sessionMap[modelType]
val session = modelState?.sessionAddresses?.get(sessionAddress)
@@ -162,19 +171,28 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
try {
val env = OrtEnvironment.getEnvironment()
val inputTensorShape = longArrayOf(1, 3, 640, 640)
if (modelType == ModelType.MobileFaceNet) {
inputTensorShape[0] = inputData.size.toLong()/ FACENET_SINGLE_INPUT_SIZE
inputTensorShape[1] = 112
inputTensorShape[2] = 112
inputTensorShape[3] = 3
} else if(modelType == ModelType.ClipImageEncoder) {
inputTensorShape[0] = 1
inputTensorShape[1] = 3
inputTensorShape[2] = 224
inputTensorShape[3] = 224
var inputTensorShape = longArrayOf(1, 3, 640, 640)
when (modelType) {
ModelType.MobileFaceNet -> {
val totalSize = inputDataFloat!!.size.toLong() / FACENET_SINGLE_INPUT_SIZE
inputTensorShape = longArrayOf(totalSize, 112, 112, 3)
}
ModelType.ClipImageEncoder -> {
inputTensorShape = longArrayOf(1, 3, 224, 224)
}
ModelType.ClipTextEncoder -> {
inputTensorShape = longArrayOf(1, 77)
}
ModelType.YOLOv5Face -> {
inputTensorShape = longArrayOf(1, 3, 640, 640)
}
}
val inputTensor = when {
inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape)
inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape)
else -> throw IllegalArgumentException("No input data provided")
}
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
val inputs = mutableMapOf<String, OnnxTensor>()
if (modelType == ModelType.MobileFaceNet) {
inputs["img_inputs"] = inputTensor

View File

@@ -22,7 +22,17 @@ class OnnxDart {
int sessionAddress = 0,
}) async {
final result = await OnnxDartPlatform.instance
.predict(inputData, modelType, sessionAddress: sessionAddress);
.predict(inputData, null, modelType, sessionAddress: sessionAddress);
return result;
}
Future<Float32List?> predictInt(
Int32List inputDataInt,
String modelType, {
int sessionAddress = 0,
}) async {
final result = await OnnxDartPlatform.instance
.predict(null, inputDataInt, modelType, sessionAddress: sessionAddress);
return result;
}
}

View File

@@ -38,7 +38,8 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
@override
Future<Float32List?> predict(
Float32List inputData,
Float32List? inputData,
Int32List? inputDataInt,
String modelType, {
int sessionAddress = 0,
}) {
@@ -47,6 +48,7 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
{
'sessionAddress': sessionAddress,
'inputData': inputData,
'inputDataInt': inputDataInt,
'modelType': modelType,
},
);

View File

@@ -41,10 +41,13 @@ abstract class OnnxDartPlatform extends PlatformInterface {
}
Future<Float32List?> predict(
Float32List inputData,
Float32List? inputData,
Int32List? inputDataInt,
String modelType, {
int sessionAddress = 0,
}) {
throw UnimplementedError('predict() has not been implemented.');
}
}