|
|
|
|
@@ -1,6 +1,7 @@
|
|
|
|
|
package io.ente.photos.onnx_dart
|
|
|
|
|
|
|
|
|
|
import android.content.Context
|
|
|
|
|
import java.nio.IntBuffer
|
|
|
|
|
import androidx.annotation.NonNull
|
|
|
|
|
import ai.onnxruntime.*
|
|
|
|
|
import java.util.EnumMap
|
|
|
|
|
@@ -35,7 +36,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
|
|
|
|
private lateinit var context: Context
|
|
|
|
|
|
|
|
|
|
enum class ModelType {
|
|
|
|
|
CLIP_TEXT, ClipImageEncoder, YOLOv5Face, MobileFaceNet
|
|
|
|
|
ClipTextEncoder, ClipImageEncoder, YOLOv5Face, MobileFaceNet
|
|
|
|
|
}
|
|
|
|
|
companion object {
|
|
|
|
|
const val DEFAULT_SESSION_COUNT = 1
|
|
|
|
|
@@ -63,41 +64,46 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
override fun onMethodCall(call: MethodCall, result: Result) {
|
|
|
|
|
if (call.method == "getPlatformVersion") {
|
|
|
|
|
result.success("Android ${android.os.Build.VERSION.RELEASE}")
|
|
|
|
|
} else if (call.method == "init") {
|
|
|
|
|
val modelType = call.argument<String>("modelType") ?: run {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model type is missing", null)
|
|
|
|
|
return
|
|
|
|
|
when (call.method) {
|
|
|
|
|
"getPlatformVersion" -> {
|
|
|
|
|
result.success("Android ${android.os.Build.VERSION.RELEASE}")
|
|
|
|
|
}
|
|
|
|
|
val modelPath = call.argument<String>("modelPath") ?: run {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model path is missing", null)
|
|
|
|
|
return
|
|
|
|
|
"init" -> {
|
|
|
|
|
val modelType = call.argument<String>("modelType")
|
|
|
|
|
val modelPath = call.argument<String>("modelPath")
|
|
|
|
|
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
|
|
|
|
|
|
|
|
|
|
if (modelType == null || modelPath == null) {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model type or path is missing", null)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
|
|
|
|
|
}
|
|
|
|
|
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
|
|
|
|
|
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
|
|
|
|
|
}
|
|
|
|
|
else if (call.method == "release" ) {
|
|
|
|
|
val modelType = call.argument<String>("modelType") ?: run {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model type is missing", null)
|
|
|
|
|
return
|
|
|
|
|
"release" -> {
|
|
|
|
|
val modelType = call.argument<String>("modelType")
|
|
|
|
|
if (modelType == null) {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model type is missing", null)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
release(ModelType.valueOf(modelType), result)
|
|
|
|
|
}
|
|
|
|
|
release(ModelType.valueOf(modelType), result)
|
|
|
|
|
}
|
|
|
|
|
else if (call.method == "predict" ) {
|
|
|
|
|
val sessionAddress = call.argument<Int>("sessionAddress")
|
|
|
|
|
val modelType = call.argument<String>("modelType") ?: run {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Model type is missing", null)
|
|
|
|
|
return
|
|
|
|
|
"predict" -> {
|
|
|
|
|
val sessionAddress = call.argument<Int>("sessionAddress")
|
|
|
|
|
val modelType = call.argument<String>("modelType")
|
|
|
|
|
val inputDataArray = call.argument<FloatArray>("inputData")
|
|
|
|
|
val inputIntDataArray = call.argument<IntArray>("inputDataInt")
|
|
|
|
|
|
|
|
|
|
if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result)
|
|
|
|
|
}
|
|
|
|
|
val inputDataArray = call.argument<FloatArray>("inputData")
|
|
|
|
|
if (sessionAddress == null || inputDataArray == null) {
|
|
|
|
|
result.error("INVALID_ARGUMENT", "Session address or input data is missing", null)
|
|
|
|
|
return
|
|
|
|
|
else -> {
|
|
|
|
|
result.notImplemented()
|
|
|
|
|
}
|
|
|
|
|
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result)
|
|
|
|
|
} else {
|
|
|
|
|
result.notImplemented()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -149,7 +155,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) {
|
|
|
|
|
private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) {
|
|
|
|
|
// Assert that exactly one of inputDataFloat or inputDataInt is provided
|
|
|
|
|
assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
|
|
|
|
|
|
|
|
|
|
scope.launch {
|
|
|
|
|
val modelState = sessionMap[modelType]
|
|
|
|
|
val session = modelState?.sessionAddresses?.get(sessionAddress)
|
|
|
|
|
@@ -162,19 +171,28 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
val env = OrtEnvironment.getEnvironment()
|
|
|
|
|
val inputTensorShape = longArrayOf(1, 3, 640, 640)
|
|
|
|
|
if (modelType == ModelType.MobileFaceNet) {
|
|
|
|
|
inputTensorShape[0] = inputData.size.toLong()/ FACENET_SINGLE_INPUT_SIZE
|
|
|
|
|
inputTensorShape[1] = 112
|
|
|
|
|
inputTensorShape[2] = 112
|
|
|
|
|
inputTensorShape[3] = 3
|
|
|
|
|
} else if(modelType == ModelType.ClipImageEncoder) {
|
|
|
|
|
inputTensorShape[0] = 1
|
|
|
|
|
inputTensorShape[1] = 3
|
|
|
|
|
inputTensorShape[2] = 224
|
|
|
|
|
inputTensorShape[3] = 224
|
|
|
|
|
var inputTensorShape = longArrayOf(1, 3, 640, 640)
|
|
|
|
|
when (modelType) {
|
|
|
|
|
ModelType.MobileFaceNet -> {
|
|
|
|
|
val totalSize = inputDataFloat!!.size.toLong() / FACENET_SINGLE_INPUT_SIZE
|
|
|
|
|
inputTensorShape = longArrayOf(totalSize, 112, 112, 3)
|
|
|
|
|
}
|
|
|
|
|
ModelType.ClipImageEncoder -> {
|
|
|
|
|
inputTensorShape = longArrayOf(1, 3, 224, 224)
|
|
|
|
|
}
|
|
|
|
|
ModelType.ClipTextEncoder -> {
|
|
|
|
|
inputTensorShape = longArrayOf(1, 77)
|
|
|
|
|
}
|
|
|
|
|
ModelType.YOLOv5Face -> {
|
|
|
|
|
inputTensorShape = longArrayOf(1, 3, 640, 640)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
val inputTensor = when {
|
|
|
|
|
inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape)
|
|
|
|
|
inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape)
|
|
|
|
|
else -> throw IllegalArgumentException("No input data provided")
|
|
|
|
|
}
|
|
|
|
|
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
|
|
|
|
|
val inputs = mutableMapOf<String, OnnxTensor>()
|
|
|
|
|
if (modelType == ModelType.MobileFaceNet) {
|
|
|
|
|
inputs["img_inputs"] = inputTensor
|
|
|
|
|
|