[mob][droid] Use custom dart plugin for face detection

This commit is contained in:
Neeraj Gupta
2024-07-08 17:41:37 +05:30
parent 37519179a1
commit 273e5b88eb
21 changed files with 913 additions and 301 deletions

View File

@@ -3,11 +3,10 @@ package io.ente.photos
import io.flutter.embedding.android.FlutterFragmentActivity
import io.flutter.embedding.engine.FlutterEngine
import io.flutter.plugins.GeneratedPluginRegistrant
import io.ente.photos.ml.onnx.EnteOnnxFlutterPlugin
class MainActivity : FlutterFragmentActivity() {
override fun configureFlutterEngine(flutterEngine: FlutterEngine) {
GeneratedPluginRegistrant.registerWith(flutterEngine)
flutterEngine.plugins.add(EnteOnnxFlutterPlugin())
}
}

View File

@@ -1,207 +0,0 @@
package io.ente.photos.ml.onnx
import android.content.Context
import androidx.annotation.NonNull
import ai.onnxruntime.*
import io.flutter.embedding.engine.plugins.FlutterPlugin
import io.flutter.plugin.common.MethodCall
import io.flutter.plugin.common.MethodChannel
import io.flutter.plugin.common.MethodChannel.MethodCallHandler
import io.flutter.plugin.common.MethodChannel.Result
import kotlinx.coroutines.*
import java.nio.FloatBuffer
import java.util.concurrent.ConcurrentHashMap
import android.util.Log
import java.io.File
import java.util.concurrent.ConcurrentLinkedQueue
object LongArrayPool {
private val pool = ConcurrentLinkedQueue<LongArray>()
fun get(size: Int): LongArray {
return pool.poll() ?: LongArray(size)
}
fun release(array: LongArray) {
pool.offer(array)
}
}
class EnteOnnxFlutterPlugin : FlutterPlugin, MethodCallHandler {
private var faceOrtEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
private lateinit var channel: MethodChannel
private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob())
private val sessionMap = ConcurrentHashMap<ModelType, ModelState>()
private lateinit var context: Context
companion object {
const val DEFAULT_SESSION_COUNT = 1
const val K_INPUT_WIDTH = 640
const val K_INPUT_HEIGHT = 640
const val K_NUM_CHANNELS = 3
}
enum class ModelType {
CLIP_TEXT, CLIP_VISUAL, YOLO_FACE, MOBILENET_FACE
}
data class ModelState(
var isInitialized: Boolean = false,
val sessionAddresses: ConcurrentHashMap<Int, OrtSession> = ConcurrentHashMap(),
// number of sessions that should have been created for given model
var sessionsCount: Int = DEFAULT_SESSION_COUNT
)
override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) {
channel = MethodChannel(flutterPluginBinding.binaryMessenger, "ente_onnx_flutter_plugin")
channel.setMethodCallHandler(this)
context = flutterPluginBinding.applicationContext
}
override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) {
channel.setMethodCallHandler(null)
releaseAllSessions()
scope.cancel()
}
override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) {
when (call.method) {
"init" -> {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
val modelPath = call.argument<String>("modelPath") ?: run {
result.error("INVALID_ARGUMENT", "Model path is missing", null)
return
}
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
}
"release" -> {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
release(ModelType.valueOf(modelType), result)
}
"predict" -> {
val sessionAddress = call.argument<Int>("sessionAddress")
val inputData = call.argument<List<Double>>("inputData")
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
if (sessionAddress == null || inputData == null) {
result.error("INVALID_ARGUMENT", "Session address or input data is missing", null)
return
}
val inputDataArray = inputData.map { it.toFloat() }.toFloatArray()
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result)
}
else -> result.notImplemented()
}
}
private fun readModelFile(modelPath: String): ByteArray {
return File(modelPath).readBytes()
}
private fun init(modelType: ModelType, modelPath: String, sessionsCount: Int, result: Result) {
scope.launch {
val modelState: ModelState
if (sessionMap.containsKey(modelType)) {
modelState = sessionMap[modelType]!!
} else {
modelState = ModelState()
sessionMap[modelType] = modelState
}
if (!modelState.isInitialized) {
for (i in 0 until sessionsCount) {
val session = createSession(faceOrtEnv, modelPath)
if (session != null) {
modelState.sessionAddresses[i] = session
}
}
modelState.isInitialized = true
modelState.sessionsCount = sessionsCount
withContext(Dispatchers.Main) {
result.success(true)
}
} else {
withContext(Dispatchers.Main) {
result.success(false)
}
}
}
}
private fun release(modelType: ModelType, result: Result) {
scope.launch {
val modelState = sessionMap[modelType]
modelState?.let {
it.sessionAddresses.forEach { entry: Map.Entry<Int, OrtSession> ->
entry.value.close()
}
it.sessionAddresses.clear()
it.isInitialized = false
}
withContext(Dispatchers.Main) {
result.success(true)
}
}
}
private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) {
scope.launch {
val modelState = sessionMap[modelType]
val session = modelState?.sessionAddresses?.get(sessionAddress)
if (session == null) {
withContext(Dispatchers.Main) {
result.error("SESSION_NOT_FOUND", "Session not found for address: $sessionAddress", null)
}
return@launch
}
try {
val env = OrtEnvironment.getEnvironment()
val inputTensorShape = LongArrayPool.get(4).apply {
this[0] = 1
this[1] = K_NUM_CHANNELS.toLong()
this[2] = K_INPUT_HEIGHT.toLong()
this[3] = K_INPUT_WIDTH.toLong()
}
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
val inputs = mapOf("input" to inputTensor)
val outputs = session.run(inputs)
Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}")
inputTensor.close()
outputs.close()
LongArrayPool.release(inputTensorShape)
withContext(Dispatchers.Main) {
val dummyResult = listOf(0.1, 0.2) // Replace with actual result processing
result.success(dummyResult)
}
} catch (e: OrtException) {
withContext(Dispatchers.Main) {
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null)
}
}
}
}
private fun createSession(env: OrtEnvironment, modalPath: String): OrtSession? {
return env.createSession(modalPath, OrtSession.SessionOptions())
}
private fun releaseAllSessions() {
sessionMap.forEach { (_, modelState) ->
modelState.sessionAddresses.forEach { entry ->
entry.value.close()
}
modelState.sessionAddresses.clear()
}
sessionMap.clear()
}
}

View File

@@ -466,7 +466,7 @@ SPEC CHECKSUMS:
package_info_plus: 115f4ad11e0698c8c1c5d8a689390df880f47e85
path_provider_foundation: 3784922295ac71e43754bd15e0653ccfd36a147c
permission_handler_apple: 9878588469a2b0d0fc1e048d9f43605f92e6cec2
photo_manager: 4f6810b7dfc4feb03b461ac1a70dacf91fba7604
photo_manager: ff695c7a1dd5bc379974953a2b5c0a293f7c4c8a
PromisesObjC: f5707f49cb48b9636751c5b2e7d227e43fba9f47
receive_sharing_intent: 6837b01768e567fe8562182397bf43d63d8c6437
screen_brightness_ios: 715ca807df953bf676d339f11464e438143ee625

View File

@@ -1,41 +0,0 @@
import 'package:flutter/services.dart';
class OnnxFlutterPlugin {
static const MethodChannel _channel =
MethodChannel('ente_onnx_flutter_plugin');
static Future<bool> init(
String modelType,
String modelPath, {
int sessionsCount = 1,
}) async {
final bool result = await _channel.invokeMethod('init', {
'modelType': modelType,
'modelPath': modelPath,
'sessionsCount': sessionsCount,
});
return result;
}
static Future<bool> release(String modelType) async {
final bool result =
await _channel.invokeMethod('release', {'modelType': modelType});
return result;
}
static Future<List<double>> predict(
List<double> inputData,
String modelType, {
int sessionAddress = 0,
}) async {
final List<dynamic> result = await _channel.invokeMethod(
'predict',
{
'sessionAddress': sessionAddress,
'inputData': inputData,
'modelType': modelType,
},
);
return result.cast<double>();
}
}

View File

@@ -1,11 +1,12 @@
import "dart:async";
import "dart:developer" as dev show log;
import "dart:io" show File;
import 'dart:typed_data' show ByteData;
import "dart:io" show File, Platform;
import 'dart:typed_data' show ByteBuffer, ByteData, Float32List, Uint8List;
import 'dart:ui' as ui show Image;
import "package:computer/computer.dart";
import 'package:logging/logging.dart';
import "package:onnx_dart/onnx_dart.dart";
import 'package:onnxruntime/onnxruntime.dart';
import "package:photos/face/model/dimension.dart";
import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
@@ -34,6 +35,7 @@ class FaceDetectionService {
static const double kIouThreshold = 0.4;
static const double kMinScoreSigmoidThreshold = 0.7;
static const int kNumKeypoints = 5;
static bool useCustomPlugin = Platform.isAndroid;
bool isInitialized = false;
@@ -49,19 +51,23 @@ class FaceDetectionService {
final model =
await RemoteAssetsService.instance.getAsset(modelRemotePath);
final startTime = DateTime.now();
sessionAddress = await _computer.compute(
_loadModel,
param: {
"modelPath": model.path,
},
);
if (useCustomPlugin) {
final OnnxDart plugin = OnnxDart();
final bool? initResult = await plugin.init("YOLO_FACE", model.path);
isInitialized = initResult ?? false;
} else {
sessionAddress = await _computer.compute(
_loadModel,
param: {
"modelPath": model.path,
},
);
isInitialized = sessionAddress != -1;
}
final endTime = DateTime.now();
_logger.info(
"Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
);
if (sessionAddress != -1) {
isInitialized = true;
}
}
}
@@ -80,7 +86,10 @@ class FaceDetectionService {
ByteData imageByteData,
int sessionAddress,
) async {
assert(sessionAddress != 0 && sessionAddress != -1);
assert(
!useCustomPlugin ? (sessionAddress != 0 && sessionAddress != -1) : true,
'sessionAddress should be valid',
);
final stopwatch = Stopwatch()..start();
@@ -94,18 +103,6 @@ class FaceDetectionService {
requiredHeight: kInputHeight,
maintainAspectRatio: true,
);
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchPreprocessing.stop();
dev.log(
'Face detection image preprocessing is finished, in ${stopwatchPreprocessing.elapsedMilliseconds}ms',
@@ -117,39 +114,90 @@ class FaceDetectionService {
// Run inference
final stopwatchInterpreter = Stopwatch()..start();
List<OrtValue?>? outputs;
List<List<List<double>>>? nestedResults = [];
try {
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
outputs = session.run(runOptions, inputs);
// inputOrt.release();
// runOptions.release();
if (useCustomPlugin) {
nestedResults = await _runCustomPlugin(inputImageList);
} else {
nestedResults = _runJNIBasedPlugin(
sessionAddress,
inputImageList,
); // [1, 25200, 16]
}
} catch (e, s) {
_logger.severe('Error while running inference: $e \n $s');
dev.log('Error while running inference', error: e, stackTrace: s);
throw YOLOFaceInterpreterRunException();
}
stopwatchInterpreter.stop();
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
try {
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
);
final relativeDetections =
_yoloPostProcessOutputs(nestedResults!, newSize);
stopwatch.stop();
_logger.info(
'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms',
);
return (relativeDetections, originalSize);
} catch (e, s) {
_logger.severe('Error while post processing', e, s);
rethrow;
}
}
static List<List<List<double>>>? _runJNIBasedPlugin(
int sessionAddress,
Float32List inputImageList,
) {
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
final relativeDetections = _yoloPostProcessOutputs(outputs, newSize);
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
final List<OrtValue?>? outputs = session.run(runOptions, inputs);
// inputOrt.release();
// runOptions.release();
return outputs?[0]?.value as List<List<List<double>>>; // [1, 25200, 16]
}
stopwatch.stop();
_logger.info(
'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms',
static Future<List<List<List<double>>>> _runCustomPlugin(
Float32List inputImageList,
) async {
final OnnxDart plugin = OnnxDart();
final result = await plugin.predict(
inputImageList,
"YOLO_FACE",
);
return (relativeDetections, originalSize);
final List<List<List<double>>> reconstructedTensor = [];
for (int i = 0; i < result.length; i += 25200 * 16) {
final List<List<double>> outerArray = [];
for (int j = 0; j < 25200; j++) {
final List<double> innerArray =
result.sublist(i + j * 16, i + (j + 1) * 16).cast<double>();
outerArray.add(innerArray);
}
reconstructedTensor.add(outerArray);
}
return reconstructedTensor;
}
static List<FaceDetectionRelative> _yoloPostProcessOutputs(
List<OrtValue?>? outputs,
List<List<List<double>>> nestedResults,
Dimensions newSize,
) {
// // Get output tensors
final nestedResults =
outputs?[0]?.value as List<List<List<double>>>; // [1, 25200, 16]
final firstResults = nestedResults[0]; // [25200, 16]
// Filter output
@@ -160,11 +208,6 @@ class FaceDetectionService {
results: firstResults,
);
// Release outputs
// outputs?.forEach((element) {
// element?.release();
// });
// Account for the fact that the aspect ratio was maintained
for (final faceDetection in relativeDetections) {
faceDetection.correctForMaintainedAspectRatio(

View File

@@ -744,6 +744,7 @@ class FaceMlService {
final sendPort = message[2] as SendPort;
try {
await FaceDetectionService.instance.init();
switch (function) {
case FaceMlOperation.analyzeImage:
final time = DateTime.now();
@@ -838,7 +839,6 @@ class FaceMlService {
final String? filePath =
await _getImagePathForML(enteFile, typeOfData: FileDataForML.fileData);
if (filePath == null) {
_logger.warning(
"Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID} since its file path is null",

View File

@@ -0,0 +1,30 @@
# This file tracks properties of this Flutter project.
# Used by Flutter tool to assess capabilities and perform upgrades etc.
#
# This file should be version controlled and should not be manually edited.
version:
revision: "761747bfc538b5af34aa0d3fac380f1bc331ec49"
channel: "[user-branch]"
project_type: plugin
# Tracks metadata for the flutter migrate command
migration:
platforms:
- platform: root
create_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49
base_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49
- platform: android
create_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49
base_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49
# User provided section
# List of Local paths (relative to this file) that should be
# ignored by the migrate tool.
#
# Files that are not part of the templates will be ignored by default.
unmanaged_files:
- 'lib/main.dart'
- 'ios/Runner.xcodeproj/project.pbxproj'

View File

@@ -0,0 +1,15 @@
# onnx_dart
A new Flutter plugin project.
## Getting Started
This project is a starting point for a Flutter
[plug-in package](https://flutter.dev/developing-packages/),
a specialized package that includes platform-specific implementation code for
Android and/or iOS.
For help getting started with Flutter development, view the
[online documentation](https://flutter.dev/docs), which offers tutorials,
samples, guidance on mobile development, and a full API reference.

View File

@@ -0,0 +1 @@
include: ../../analysis_options.yaml

View File

@@ -0,0 +1,70 @@
group = "io.ente.photos.onnx_dart"
version = "1.0-SNAPSHOT"
buildscript {
ext.kotlin_version = "1.7.10"
repositories {
google()
mavenCentral()
}
dependencies {
classpath("com.android.tools.build:gradle:7.3.0")
classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version")
}
}
allprojects {
repositories {
google()
mavenCentral()
}
}
apply plugin: "com.android.library"
apply plugin: "kotlin-android"
android {
if (project.android.hasProperty("namespace")) {
namespace = "io.ente.photos.onnx_dart"
}
compileSdk = 34
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
sourceSets {
main.java.srcDirs += "src/main/kotlin"
test.java.srcDirs += "src/test/kotlin"
}
defaultConfig {
minSdk = 21
}
dependencies {
testImplementation("org.jetbrains.kotlin:kotlin-test")
testImplementation("org.mockito:mockito-core:5.0.0")
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
}
testOptions {
unitTests.all {
useJUnitPlatform()
testLogging {
events "passed", "skipped", "failed", "standardOut", "standardError"
outputs.upToDateWhen {false}
showStandardStreams = true
}
}
}
}

View File

@@ -0,0 +1,2 @@
sdk.dir=/Users/ua741/Library/Android/sdk
flutter.sdk=/Users/ua741/work/flutter

View File

@@ -0,0 +1 @@
rootProject.name = 'onnx_dart'

View File

@@ -0,0 +1,3 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="io.ente.photos.onnx_dart">
</manifest>

View File

@@ -0,0 +1,253 @@
package io.ente.photos.onnx_dart
import android.content.Context
import androidx.annotation.NonNull
import ai.onnxruntime.*
import io.flutter.embedding.engine.plugins.FlutterPlugin
import io.flutter.plugin.common.MethodCall
import io.flutter.plugin.common.MethodChannel
import io.flutter.plugin.common.MethodChannel.MethodCallHandler
import io.flutter.plugin.common.MethodChannel.Result
import kotlinx.coroutines.*
import java.nio.FloatBuffer
import java.util.concurrent.ConcurrentHashMap
import android.util.Log
import java.io.File
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.concurrent.ConcurrentLinkedQueue
object LongArrayPool {
private val poolMap = ConcurrentHashMap<Int, ConcurrentLinkedQueue<LongArray>>()
fun get(size: Int): LongArray {
val pool = getPool(size)
return pool.poll() ?: LongArray(size)
}
fun release(array: LongArray) {
val pool = getPool(array.size)
pool.offer(array)
}
private fun getPool(size: Int): ConcurrentLinkedQueue<LongArray> {
var pool = poolMap[size]
if (pool == null) {
synchronized(poolMap) {
pool = poolMap[size]
if (pool == null) {
pool = ConcurrentLinkedQueue()
poolMap[size] = pool!!
}
}
}
return pool!!
}
}
/** OnnxDartPlugin */
class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
/// The MethodChannel that will the communication between Flutter and native Android
///
/// This local reference serves to register the plugin with the Flutter Engine and unregister it
/// when the Flutter Engine is detached from the Activity
private lateinit var channel : MethodChannel
private var faceOrtEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob())
private val sessionMap = ConcurrentHashMap<ModelType, ModelState>()
private lateinit var context: Context
companion object {
const val DEFAULT_SESSION_COUNT = 1
const val K_INPUT_WIDTH = 640
const val K_INPUT_HEIGHT = 640
const val K_NUM_CHANNELS = 3
}
enum class ModelType {
CLIP_TEXT, CLIP_VISUAL, YOLO_FACE, MOBILENET_FACE
}
data class ModelState(
var isInitialized: Boolean = false,
val sessionAddresses: ConcurrentHashMap<Int, OrtSession> = ConcurrentHashMap(),
// number of sessions that should have been created for given model
var sessionsCount: Int = DEFAULT_SESSION_COUNT
)
override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) {
channel = MethodChannel(flutterPluginBinding.binaryMessenger, "onnx_dart")
channel.setMethodCallHandler(this)
context = flutterPluginBinding.applicationContext
}
override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) {
channel.setMethodCallHandler(null)
}
override fun onMethodCall(call: MethodCall, result: Result) {
if (call.method == "getPlatformVersion") {
result.success("Android ${android.os.Build.VERSION.RELEASE}")
} else if (call.method == "init") {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
val modelPath = call.argument<String>("modelPath") ?: run {
result.error("INVALID_ARGUMENT", "Model path is missing", null)
return
}
val sessionsCount = call.argument<Int>("sessionsCount") ?: DEFAULT_SESSION_COUNT
init(ModelType.valueOf(modelType), modelPath, sessionsCount, result)
}
else if (call.method == "release" ) {
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
release(ModelType.valueOf(modelType), result)
}
else if (call.method == "predict" ) {
val sessionAddress = call.argument<Int>("sessionAddress")
val modelType = call.argument<String>("modelType") ?: run {
result.error("INVALID_ARGUMENT", "Model type is missing", null)
return
}
val inputDataArray = call.argument<FloatArray>("inputData")
if (sessionAddress == null || inputDataArray == null) {
result.error("INVALID_ARGUMENT", "Session address or input data is missing", null)
return
}
// val inputDataArray = inputData.map { it.toFloat() }.toFloatArray()
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result)
} else {
result.notImplemented()
}
}
private fun readModelFile(modelPath: String): ByteArray {
return File(modelPath).readBytes()
}
private fun init(modelType: ModelType, modelPath: String, sessionsCount: Int, result: Result) {
Log.d("OnnxFlutterPlugin", " v: $modelType, path: $modelPath, sessionsCount: $sessionsCount")
scope.launch {
val modelState: ModelState
if (sessionMap.containsKey(modelType)) {
modelState = sessionMap[modelType]!!
} else {
modelState = ModelState()
sessionMap[modelType] = modelState
}
if (!modelState.isInitialized) {
for (i in 0 until sessionsCount) {
val session = createSession(faceOrtEnv, modelPath)
if (session != null) {
modelState.sessionAddresses[i] = session
}
}
modelState.isInitialized = true
modelState.sessionsCount = sessionsCount
withContext(Dispatchers.Main) {
Log.d("OnnxFlutterPlugin", "Model initialized: $modelType")
result.success(true)
}
} else {
withContext(Dispatchers.Main) {
result.success(false)
}
}
}
}
private fun release(modelType: ModelType, result: Result) {
scope.launch {
val modelState = sessionMap[modelType]
modelState?.let {
it.sessionAddresses.forEach { entry: Map.Entry<Int, OrtSession> ->
entry.value.close()
}
it.sessionAddresses.clear()
it.isInitialized = false
}
withContext(Dispatchers.Main) {
result.success(true)
}
}
}
private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) {
scope.launch {
val modelState = sessionMap[modelType]
val session = modelState?.sessionAddresses?.get(sessionAddress)
if (session == null) {
withContext(Dispatchers.Main) {
result.error("SESSION_NOT_FOUND", "Session not found for address: $sessionAddress", null)
}
return@launch
}
try {
val env = OrtEnvironment.getEnvironment()
val inputTensorShape = LongArrayPool.get(4).apply {
this[0] = 1
this[1] = K_NUM_CHANNELS.toLong()
this[2] = K_INPUT_HEIGHT.toLong()
this[3] = K_INPUT_WIDTH.toLong()
}
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
val inputs = mapOf("input" to inputTensor)
val outputs = session.run(inputs)
Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}")
inputTensor.close()
val totalSize = 1 * 25200 * 16
val flatArray = FloatArray(totalSize) { index -> index + 1.0f }
val outputTensor2 = Array(1) { outerIndex ->
Array(25200) { innerIndex ->
val startIndex = (outerIndex * 25200 + innerIndex) * 16
flatArray.sliceArray(startIndex until startIndex + 16)
}
}
val outputTensor = outputs[0].value as Array<Array<FloatArray>>
Log.d("OnnxFlutterPlugin", "Output2 shape: ${outputTensor.size}")
outputs.close()
LongArrayPool.release(inputTensorShape)
// Send the result back to the Dart layer
val flatList = outputTensor.flatten().flatMap { it.toList() }
val flatList2 = outputTensor2.flatten().flatMap { it.toList() }
withContext(Dispatchers.Main) {
result.success(flatList)
// result.success(flatList2)
}
} catch (e: OrtException) {
withContext(Dispatchers.Main) {
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null)
}
}
}
}
private fun createSession(env: OrtEnvironment, modalPath: String): OrtSession? {
return env.createSession(modalPath, OrtSession.SessionOptions())
}
private fun releaseAllSessions() {
sessionMap.forEach { (_, modelState) ->
modelState.sessionAddresses.forEach { entry ->
entry.value.close()
}
modelState.sessionAddresses.clear()
}
sessionMap.clear()
}
}

View File

@@ -0,0 +1,27 @@
import 'dart:typed_data';
import 'package:onnx_dart/onnx_dart_platform_interface.dart';
class OnnxDart {
Future<String?> getPlatformVersion() {
return OnnxDartPlatform.instance.getPlatformVersion();
}
Future<bool?> init(
String modelType,
String modelPath, {
int sessionsCount = 1,
}) {
return OnnxDartPlatform.instance
.init(modelType, modelPath, sessionsCount: sessionsCount);
}
Future<dynamic?> predict(
Float32List inputData,
String modelType, {
int sessionAddress = 0,
}) async {
return OnnxDartPlatform.instance
.predict(inputData, modelType, sessionAddress: sessionAddress);
}
}

View File

@@ -0,0 +1,73 @@
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'onnx_dart_platform_interface.dart';
/// An implementation of [OnnxDartPlatform] that uses method channels.
class MethodChannelOnnxDart extends OnnxDartPlatform {
/// The method channel used to interact with the native platform.
@visibleForTesting
final methodChannel = const MethodChannel('onnx_dart');
@override
Future<String?> getPlatformVersion() async {
final version =
await methodChannel.invokeMethod<String>('getPlatformVersion');
return version;
}
@override
Future<bool?> init(
String modelType,
String modelPath, {
int sessionsCount = 1,
}) async {
final result = await methodChannel.invokeMethod<bool>('init', {
'modelType': modelType,
'modelPath': modelPath,
'sessionsCount': sessionsCount,
});
return result;
}
@override
Future<bool?> release(String modelType) async {
final bool? result = await methodChannel
.invokeMethod<bool>('release', {'modelType': modelType});
return result;
}
// @override
// Future<List<double>?> predict(
// List<double> inputData,
// String modelType, {
// int sessionAddress = 0,
// }) async {
// final List<dynamic>? result =
// await methodChannel.invokeMethod<List<double>?>(
// 'predict',
// {
// 'sessionAddress': sessionAddress,
// 'inputData': inputData,
// 'modelType': modelType,
// },
// );
// return result!.cast<double>();
// }
@override
Future<dynamic?> predict(
Float32List inputData,
String modelType, {
int sessionAddress = 0,
}) {
return methodChannel.invokeMethod<dynamic?>(
'predict',
{
'sessionAddress': sessionAddress,
'inputData': inputData,
'modelType': modelType,
},
);
}
}

View File

@@ -0,0 +1,51 @@
import 'dart:typed_data';
import 'package:plugin_platform_interface/plugin_platform_interface.dart';
import 'onnx_dart_method_channel.dart';
abstract class OnnxDartPlatform extends PlatformInterface {
/// Constructs a OnnxDartPlatform.
OnnxDartPlatform() : super(token: _token);
static final Object _token = Object();
static OnnxDartPlatform _instance = MethodChannelOnnxDart();
/// The default instance of [OnnxDartPlatform] to use.
///
/// Defaults to [MethodChannelOnnxDart].
static OnnxDartPlatform get instance => _instance;
/// Platform-specific implementations should set this with their own
/// platform-specific class that extends [OnnxDartPlatform] when
/// they register themselves.
static set instance(OnnxDartPlatform instance) {
PlatformInterface.verifyToken(instance, _token);
_instance = instance;
}
Future<String?> getPlatformVersion() {
throw UnimplementedError('platformVersion() has not been implemented.');
}
Future<bool?> init(
String modelType,
String modelPath, {
int sessionsCount = 1,
}) {
throw UnimplementedError('init() has not been implemented.');
}
Future<bool?> release(String modelType) {
throw UnimplementedError('release() has not been implemented.');
}
Future<dynamic?> predict(
Float32List inputData,
String modelType, {
int sessionAddress = 0,
}) {
throw UnimplementedError('predict() has not been implemented.');
}
}

View File

@@ -0,0 +1,213 @@
# Generated by pub
# See https://dart.dev/tools/pub/glossary#lockfile
packages:
async:
dependency: transitive
description:
name: async
sha256: "947bfcf187f74dbc5e146c9eb9c0f10c9f8b30743e341481c1e2ed3ecc18c20c"
url: "https://pub.dev"
source: hosted
version: "2.11.0"
boolean_selector:
dependency: transitive
description:
name: boolean_selector
sha256: "6cfb5af12253eaf2b368f07bacc5a80d1301a071c73360d746b7f2e32d762c66"
url: "https://pub.dev"
source: hosted
version: "2.1.1"
characters:
dependency: transitive
description:
name: characters
sha256: "04a925763edad70e8443c99234dc3328f442e811f1d8fd1a72f1c8ad0f69a605"
url: "https://pub.dev"
source: hosted
version: "1.3.0"
clock:
dependency: transitive
description:
name: clock
sha256: cb6d7f03e1de671e34607e909a7213e31d7752be4fb66a86d29fe1eb14bfb5cf
url: "https://pub.dev"
source: hosted
version: "1.1.1"
collection:
dependency: transitive
description:
name: collection
sha256: ee67cb0715911d28db6bf4af1026078bd6f0128b07a5f66fb2ed94ec6783c09a
url: "https://pub.dev"
source: hosted
version: "1.18.0"
fake_async:
dependency: transitive
description:
name: fake_async
sha256: "511392330127add0b769b75a987850d136345d9227c6b94c96a04cf4a391bf78"
url: "https://pub.dev"
source: hosted
version: "1.3.1"
flutter:
dependency: "direct main"
description: flutter
source: sdk
version: "0.0.0"
flutter_lints:
dependency: "direct dev"
description:
name: flutter_lints
sha256: "9e8c3858111da373efc5aa341de011d9bd23e2c5c5e0c62bccf32438e192d7b1"
url: "https://pub.dev"
source: hosted
version: "3.0.2"
flutter_test:
dependency: "direct dev"
description: flutter
source: sdk
version: "0.0.0"
leak_tracker:
dependency: transitive
description:
name: leak_tracker
sha256: "7f0df31977cb2c0b88585095d168e689669a2cc9b97c309665e3386f3e9d341a"
url: "https://pub.dev"
source: hosted
version: "10.0.4"
leak_tracker_flutter_testing:
dependency: transitive
description:
name: leak_tracker_flutter_testing
sha256: "06e98f569d004c1315b991ded39924b21af84cf14cc94791b8aea337d25b57f8"
url: "https://pub.dev"
source: hosted
version: "3.0.3"
leak_tracker_testing:
dependency: transitive
description:
name: leak_tracker_testing
sha256: "6ba465d5d76e67ddf503e1161d1f4a6bc42306f9d66ca1e8f079a47290fb06d3"
url: "https://pub.dev"
source: hosted
version: "3.0.1"
lints:
dependency: transitive
description:
name: lints
sha256: cbf8d4b858bb0134ef3ef87841abdf8d63bfc255c266b7bf6b39daa1085c4290
url: "https://pub.dev"
source: hosted
version: "3.0.0"
matcher:
dependency: transitive
description:
name: matcher
sha256: d2323aa2060500f906aa31a895b4030b6da3ebdcc5619d14ce1aada65cd161cb
url: "https://pub.dev"
source: hosted
version: "0.12.16+1"
material_color_utilities:
dependency: transitive
description:
name: material_color_utilities
sha256: "0e0a020085b65b6083975e499759762399b4475f766c21668c4ecca34ea74e5a"
url: "https://pub.dev"
source: hosted
version: "0.8.0"
meta:
dependency: transitive
description:
name: meta
sha256: "7687075e408b093f36e6bbf6c91878cc0d4cd10f409506f7bc996f68220b9136"
url: "https://pub.dev"
source: hosted
version: "1.12.0"
path:
dependency: transitive
description:
name: path
sha256: "087ce49c3f0dc39180befefc60fdb4acd8f8620e5682fe2476afd0b3688bb4af"
url: "https://pub.dev"
source: hosted
version: "1.9.0"
plugin_platform_interface:
dependency: "direct main"
description:
name: plugin_platform_interface
sha256: "4820fbfdb9478b1ebae27888254d445073732dae3d6ea81f0b7e06d5dedc3f02"
url: "https://pub.dev"
source: hosted
version: "2.1.8"
sky_engine:
dependency: transitive
description: flutter
source: sdk
version: "0.0.99"
source_span:
dependency: transitive
description:
name: source_span
sha256: "53e943d4206a5e30df338fd4c6e7a077e02254531b138a15aec3bd143c1a8b3c"
url: "https://pub.dev"
source: hosted
version: "1.10.0"
stack_trace:
dependency: transitive
description:
name: stack_trace
sha256: "73713990125a6d93122541237550ee3352a2d84baad52d375a4cad2eb9b7ce0b"
url: "https://pub.dev"
source: hosted
version: "1.11.1"
stream_channel:
dependency: transitive
description:
name: stream_channel
sha256: ba2aa5d8cc609d96bbb2899c28934f9e1af5cddbd60a827822ea467161eb54e7
url: "https://pub.dev"
source: hosted
version: "2.1.2"
string_scanner:
dependency: transitive
description:
name: string_scanner
sha256: "556692adab6cfa87322a115640c11f13cb77b3f076ddcc5d6ae3c20242bedcde"
url: "https://pub.dev"
source: hosted
version: "1.2.0"
term_glyph:
dependency: transitive
description:
name: term_glyph
sha256: a29248a84fbb7c79282b40b8c72a1209db169a2e0542bce341da992fe1bc7e84
url: "https://pub.dev"
source: hosted
version: "1.2.1"
test_api:
dependency: transitive
description:
name: test_api
sha256: "9955ae474176f7ac8ee4e989dadfb411a58c30415bcfb648fa04b2b8a03afa7f"
url: "https://pub.dev"
source: hosted
version: "0.7.0"
vector_math:
dependency: transitive
description:
name: vector_math
sha256: "80b3257d1492ce4d091729e3a67a60407d227c27241d6927be0130c98e741803"
url: "https://pub.dev"
source: hosted
version: "2.1.4"
vm_service:
dependency: transitive
description:
name: vm_service
sha256: "3923c89304b715fb1eb6423f017651664a03bf5f4b29983627c4da791f74a4ec"
url: "https://pub.dev"
source: hosted
version: "14.2.1"
sdks:
dart: ">=3.4.3 <4.0.0"
flutter: ">=3.18.0-18.0.pre.54"

View File

@@ -0,0 +1,70 @@
name: onnx_dart
description: "A new Flutter plugin project."
version: 0.0.1
homepage:
environment:
sdk: '>=3.4.3 <4.0.0'
flutter: '>=3.3.0'
dependencies:
flutter:
sdk: flutter
plugin_platform_interface: ^2.0.2
dev_dependencies:
flutter_test:
sdk: flutter
flutter_lints: ^3.0.0
# For information on the generic Dart part of this file, see the
# following page: https://dart.dev/tools/pub/pubspec
# The following section is specific to Flutter packages.
flutter:
# This section identifies this Flutter project as a plugin project.
# The 'pluginClass' specifies the class (in Java, Kotlin, Swift, Objective-C, etc.)
# which should be registered in the plugin registry. This is required for
# using method channels.
# The Android 'package' specifies package in which the registered class is.
# This is required for using method channels on Android.
# The 'ffiPlugin' specifies that native code should be built and bundled.
# This is required for using `dart:ffi`.
# All these are used by the tooling to maintain consistency when
# adding or updating assets for this project.
plugin:
platforms:
android:
package: io.ente.photos.onnx_dart
pluginClass: OnnxDartPlugin
# To add assets to your plugin package, add an assets section, like this:
# assets:
# - images/a_dot_burr.jpeg
# - images/a_dot_ham.jpeg
#
# For details regarding assets in packages, see
# https://flutter.dev/assets-and-images/#from-packages
#
# An image asset can refer to one or more resolution-specific "variants", see
# https://flutter.dev/assets-and-images/#resolution-aware
# To add custom fonts to your plugin package, add a fonts section here,
# in this "flutter" section. Each entry in this list should have a
# "family" key with the font family name, and a "fonts" key with a
# list giving the asset and other descriptors for the font. For
# example:
# fonts:
# - family: Schyler
# fonts:
# - asset: fonts/Schyler-Regular.ttf
# - asset: fonts/Schyler-Italic.ttf
# style: italic
# - family: Trajan Pro
# fonts:
# - asset: fonts/TrajanPro.ttf
# - asset: fonts/TrajanPro_Bold.ttf
# weight: 700
#
# For details regarding fonts in packages, see
# https://flutter.dev/custom-fonts/#from-packages

View File

@@ -1606,6 +1606,13 @@ packages:
url: "https://pub.dev"
source: hosted
version: "2.0.0"
onnx_dart:
dependency: "direct main"
description:
path: "plugins/onnx_dart"
relative: true
source: path
version: "0.0.1"
onnxruntime:
dependency: "direct main"
description:
@@ -2775,5 +2782,5 @@ packages:
source: hosted
version: "3.1.2"
sdks:
dart: ">=3.4.0 <4.0.0"
dart: ">=3.4.3 <4.0.0"
flutter: ">=3.22.0"

View File

@@ -128,6 +128,8 @@ dependencies:
git:
url: https://github.com/ente-io/onnxruntime.git
ref: ente_onnxruntime
onnx_dart:
path: plugins/onnx_dart
open_mail_app: ^0.4.5
package_info_plus: ^4.1.0
page_transition: ^2.0.2