From 273e5b88ebeb3738ded74b4c10ca8d40350bd5d0 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:41:37 +0530 Subject: [PATCH] [mob][droid] Use custom dart plugin for face detection --- .../kotlin/io/ente/photos/MainActivity.kt | 3 +- .../kotlin/io/ente/photos/ml/onnx/EnteOnnx.kt | 207 -------------- mobile/ios/Podfile.lock | 2 +- mobile/lib/nativeplugins/onnx.dart | 41 --- .../face_detection_service.dart | 139 ++++++---- .../face_ml/face_ml_service.dart | 2 +- mobile/plugins/onnx_dart/.metadata | 30 +++ mobile/plugins/onnx_dart/README.md | 15 ++ .../plugins/onnx_dart/analysis_options.yaml | 1 + mobile/plugins/onnx_dart/android/build.gradle | 70 +++++ .../onnx_dart/android/local.properties | 2 + .../plugins/onnx_dart/android/settings.gradle | 1 + .../android/src/main/AndroidManifest.xml | 3 + .../ente/photos/onnx_dart/OnnxDartPlugin.kt | 253 ++++++++++++++++++ mobile/plugins/onnx_dart/lib/onnx_dart.dart | 27 ++ .../lib/onnx_dart_method_channel.dart | 73 +++++ .../lib/onnx_dart_platform_interface.dart | 51 ++++ mobile/plugins/onnx_dart/pubspec.lock | 213 +++++++++++++++ mobile/plugins/onnx_dart/pubspec.yaml | 70 +++++ mobile/pubspec.lock | 9 +- mobile/pubspec.yaml | 2 + 21 files changed, 913 insertions(+), 301 deletions(-) delete mode 100644 mobile/android/app/src/main/kotlin/io/ente/photos/ml/onnx/EnteOnnx.kt delete mode 100644 mobile/lib/nativeplugins/onnx.dart create mode 100644 mobile/plugins/onnx_dart/.metadata create mode 100644 mobile/plugins/onnx_dart/README.md create mode 100644 mobile/plugins/onnx_dart/analysis_options.yaml create mode 100644 mobile/plugins/onnx_dart/android/build.gradle create mode 100644 mobile/plugins/onnx_dart/android/local.properties create mode 100644 mobile/plugins/onnx_dart/android/settings.gradle create mode 100644 mobile/plugins/onnx_dart/android/src/main/AndroidManifest.xml create mode 100644 mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt create mode 100644 mobile/plugins/onnx_dart/lib/onnx_dart.dart create mode 100644 mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart create mode 100644 mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart create mode 100644 mobile/plugins/onnx_dart/pubspec.lock create mode 100644 mobile/plugins/onnx_dart/pubspec.yaml diff --git a/mobile/android/app/src/main/kotlin/io/ente/photos/MainActivity.kt b/mobile/android/app/src/main/kotlin/io/ente/photos/MainActivity.kt index 227bc1788e..75edded083 100644 --- a/mobile/android/app/src/main/kotlin/io/ente/photos/MainActivity.kt +++ b/mobile/android/app/src/main/kotlin/io/ente/photos/MainActivity.kt @@ -3,11 +3,10 @@ package io.ente.photos import io.flutter.embedding.android.FlutterFragmentActivity import io.flutter.embedding.engine.FlutterEngine import io.flutter.plugins.GeneratedPluginRegistrant -import io.ente.photos.ml.onnx.EnteOnnxFlutterPlugin class MainActivity : FlutterFragmentActivity() { override fun configureFlutterEngine(flutterEngine: FlutterEngine) { GeneratedPluginRegistrant.registerWith(flutterEngine) - flutterEngine.plugins.add(EnteOnnxFlutterPlugin()) + } } diff --git a/mobile/android/app/src/main/kotlin/io/ente/photos/ml/onnx/EnteOnnx.kt b/mobile/android/app/src/main/kotlin/io/ente/photos/ml/onnx/EnteOnnx.kt deleted file mode 100644 index b301850681..0000000000 --- a/mobile/android/app/src/main/kotlin/io/ente/photos/ml/onnx/EnteOnnx.kt +++ /dev/null @@ -1,207 +0,0 @@ -package io.ente.photos.ml.onnx - -import android.content.Context -import androidx.annotation.NonNull -import ai.onnxruntime.* -import io.flutter.embedding.engine.plugins.FlutterPlugin -import io.flutter.plugin.common.MethodCall -import io.flutter.plugin.common.MethodChannel -import io.flutter.plugin.common.MethodChannel.MethodCallHandler -import io.flutter.plugin.common.MethodChannel.Result -import kotlinx.coroutines.* -import java.nio.FloatBuffer -import java.util.concurrent.ConcurrentHashMap -import android.util.Log -import java.io.File -import java.util.concurrent.ConcurrentLinkedQueue - -object LongArrayPool { - private val pool = ConcurrentLinkedQueue() - - fun get(size: Int): LongArray { - return pool.poll() ?: LongArray(size) - } - - fun release(array: LongArray) { - pool.offer(array) - } -} - -class EnteOnnxFlutterPlugin : FlutterPlugin, MethodCallHandler { - private var faceOrtEnv: OrtEnvironment = OrtEnvironment.getEnvironment() - private lateinit var channel: MethodChannel - private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) - private val sessionMap = ConcurrentHashMap() - private lateinit var context: Context - - companion object { - const val DEFAULT_SESSION_COUNT = 1 - const val K_INPUT_WIDTH = 640 - const val K_INPUT_HEIGHT = 640 - const val K_NUM_CHANNELS = 3 - } - - enum class ModelType { - CLIP_TEXT, CLIP_VISUAL, YOLO_FACE, MOBILENET_FACE - } - - data class ModelState( - var isInitialized: Boolean = false, - val sessionAddresses: ConcurrentHashMap = ConcurrentHashMap(), - // number of sessions that should have been created for given model - var sessionsCount: Int = DEFAULT_SESSION_COUNT - ) - - override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { - channel = MethodChannel(flutterPluginBinding.binaryMessenger, "ente_onnx_flutter_plugin") - channel.setMethodCallHandler(this) - context = flutterPluginBinding.applicationContext - } - - override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) { - channel.setMethodCallHandler(null) - releaseAllSessions() - scope.cancel() - } - - override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) { - when (call.method) { - "init" -> { - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return - } - val modelPath = call.argument("modelPath") ?: run { - result.error("INVALID_ARGUMENT", "Model path is missing", null) - return - } - val sessionsCount = call.argument("sessionsCount") ?: DEFAULT_SESSION_COUNT - init(ModelType.valueOf(modelType), modelPath, sessionsCount, result) - } - "release" -> { - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return - } - release(ModelType.valueOf(modelType), result) - } - "predict" -> { - val sessionAddress = call.argument("sessionAddress") - val inputData = call.argument>("inputData") - val modelType = call.argument("modelType") ?: run { - result.error("INVALID_ARGUMENT", "Model type is missing", null) - return - } - if (sessionAddress == null || inputData == null) { - result.error("INVALID_ARGUMENT", "Session address or input data is missing", null) - return - } - val inputDataArray = inputData.map { it.toFloat() }.toFloatArray() - predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result) - } - else -> result.notImplemented() - } - } - - private fun readModelFile(modelPath: String): ByteArray { - return File(modelPath).readBytes() - } - - private fun init(modelType: ModelType, modelPath: String, sessionsCount: Int, result: Result) { - scope.launch { - val modelState: ModelState - if (sessionMap.containsKey(modelType)) { - modelState = sessionMap[modelType]!! - } else { - modelState = ModelState() - sessionMap[modelType] = modelState - } - if (!modelState.isInitialized) { - for (i in 0 until sessionsCount) { - val session = createSession(faceOrtEnv, modelPath) - if (session != null) { - modelState.sessionAddresses[i] = session - } - } - modelState.isInitialized = true - modelState.sessionsCount = sessionsCount - withContext(Dispatchers.Main) { - result.success(true) - } - } else { - withContext(Dispatchers.Main) { - result.success(false) - } - } - } - } - - private fun release(modelType: ModelType, result: Result) { - scope.launch { - val modelState = sessionMap[modelType] - modelState?.let { - it.sessionAddresses.forEach { entry: Map.Entry -> - entry.value.close() - } - it.sessionAddresses.clear() - it.isInitialized = false - } - withContext(Dispatchers.Main) { - result.success(true) - } - } - } - - private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) { - scope.launch { - val modelState = sessionMap[modelType] - val session = modelState?.sessionAddresses?.get(sessionAddress) - if (session == null) { - withContext(Dispatchers.Main) { - result.error("SESSION_NOT_FOUND", "Session not found for address: $sessionAddress", null) - } - return@launch - } - - try { - val env = OrtEnvironment.getEnvironment() - val inputTensorShape = LongArrayPool.get(4).apply { - this[0] = 1 - this[1] = K_NUM_CHANNELS.toLong() - this[2] = K_INPUT_HEIGHT.toLong() - this[3] = K_INPUT_WIDTH.toLong() - } - val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape) - val inputs = mapOf("input" to inputTensor) - val outputs = session.run(inputs) - Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}") - - inputTensor.close() - outputs.close() - LongArrayPool.release(inputTensorShape) - withContext(Dispatchers.Main) { - val dummyResult = listOf(0.1, 0.2) // Replace with actual result processing - result.success(dummyResult) - } - } catch (e: OrtException) { - withContext(Dispatchers.Main) { - result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null) - } - } - } - } - - private fun createSession(env: OrtEnvironment, modalPath: String): OrtSession? { - return env.createSession(modalPath, OrtSession.SessionOptions()) - } - - private fun releaseAllSessions() { - sessionMap.forEach { (_, modelState) -> - modelState.sessionAddresses.forEach { entry -> - entry.value.close() - } - modelState.sessionAddresses.clear() - } - sessionMap.clear() - } -} diff --git a/mobile/ios/Podfile.lock b/mobile/ios/Podfile.lock index fd11b0a5b1..5f03149e8d 100644 --- a/mobile/ios/Podfile.lock +++ b/mobile/ios/Podfile.lock @@ -466,7 +466,7 @@ SPEC CHECKSUMS: package_info_plus: 115f4ad11e0698c8c1c5d8a689390df880f47e85 path_provider_foundation: 3784922295ac71e43754bd15e0653ccfd36a147c permission_handler_apple: 9878588469a2b0d0fc1e048d9f43605f92e6cec2 - photo_manager: 4f6810b7dfc4feb03b461ac1a70dacf91fba7604 + photo_manager: ff695c7a1dd5bc379974953a2b5c0a293f7c4c8a PromisesObjC: f5707f49cb48b9636751c5b2e7d227e43fba9f47 receive_sharing_intent: 6837b01768e567fe8562182397bf43d63d8c6437 screen_brightness_ios: 715ca807df953bf676d339f11464e438143ee625 diff --git a/mobile/lib/nativeplugins/onnx.dart b/mobile/lib/nativeplugins/onnx.dart deleted file mode 100644 index 58f8bfe7d2..0000000000 --- a/mobile/lib/nativeplugins/onnx.dart +++ /dev/null @@ -1,41 +0,0 @@ -import 'package:flutter/services.dart'; - -class OnnxFlutterPlugin { - static const MethodChannel _channel = - MethodChannel('ente_onnx_flutter_plugin'); - - static Future init( - String modelType, - String modelPath, { - int sessionsCount = 1, - }) async { - final bool result = await _channel.invokeMethod('init', { - 'modelType': modelType, - 'modelPath': modelPath, - 'sessionsCount': sessionsCount, - }); - return result; - } - - static Future release(String modelType) async { - final bool result = - await _channel.invokeMethod('release', {'modelType': modelType}); - return result; - } - - static Future> predict( - List inputData, - String modelType, { - int sessionAddress = 0, - }) async { - final List result = await _channel.invokeMethod( - 'predict', - { - 'sessionAddress': sessionAddress, - 'inputData': inputData, - 'modelType': modelType, - }, - ); - return result.cast(); - } -} diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart index 455f4e9c8d..9935ed19ac 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart @@ -1,11 +1,12 @@ import "dart:async"; import "dart:developer" as dev show log; -import "dart:io" show File; -import 'dart:typed_data' show ByteData; +import "dart:io" show File, Platform; +import 'dart:typed_data' show ByteBuffer, ByteData, Float32List, Uint8List; import 'dart:ui' as ui show Image; import "package:computer/computer.dart"; import 'package:logging/logging.dart'; +import "package:onnx_dart/onnx_dart.dart"; import 'package:onnxruntime/onnxruntime.dart'; import "package:photos/face/model/dimension.dart"; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; @@ -34,6 +35,7 @@ class FaceDetectionService { static const double kIouThreshold = 0.4; static const double kMinScoreSigmoidThreshold = 0.7; static const int kNumKeypoints = 5; + static bool useCustomPlugin = Platform.isAndroid; bool isInitialized = false; @@ -49,19 +51,23 @@ class FaceDetectionService { final model = await RemoteAssetsService.instance.getAsset(modelRemotePath); final startTime = DateTime.now(); - sessionAddress = await _computer.compute( - _loadModel, - param: { - "modelPath": model.path, - }, - ); + if (useCustomPlugin) { + final OnnxDart plugin = OnnxDart(); + final bool? initResult = await plugin.init("YOLO_FACE", model.path); + isInitialized = initResult ?? false; + } else { + sessionAddress = await _computer.compute( + _loadModel, + param: { + "modelPath": model.path, + }, + ); + isInitialized = sessionAddress != -1; + } final endTime = DateTime.now(); _logger.info( "Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", ); - if (sessionAddress != -1) { - isInitialized = true; - } } } @@ -80,7 +86,10 @@ class FaceDetectionService { ByteData imageByteData, int sessionAddress, ) async { - assert(sessionAddress != 0 && sessionAddress != -1); + assert( + !useCustomPlugin ? (sessionAddress != 0 && sessionAddress != -1) : true, + 'sessionAddress should be valid', + ); final stopwatch = Stopwatch()..start(); @@ -94,18 +103,6 @@ class FaceDetectionService { requiredHeight: kInputHeight, maintainAspectRatio: true, ); - - final inputShape = [ - 1, - 3, - kInputHeight, - kInputWidth, - ]; - final inputOrt = OrtValueTensor.createTensorWithDataList( - inputImageList, - inputShape, - ); - final inputs = {'input': inputOrt}; stopwatchPreprocessing.stop(); dev.log( 'Face detection image preprocessing is finished, in ${stopwatchPreprocessing.elapsedMilliseconds}ms', @@ -117,39 +114,90 @@ class FaceDetectionService { // Run inference final stopwatchInterpreter = Stopwatch()..start(); - List? outputs; + + List>>? nestedResults = []; try { - final runOptions = OrtRunOptions(); - final session = OrtSession.fromAddress(sessionAddress); - outputs = session.run(runOptions, inputs); - // inputOrt.release(); - // runOptions.release(); + if (useCustomPlugin) { + nestedResults = await _runCustomPlugin(inputImageList); + } else { + nestedResults = _runJNIBasedPlugin( + sessionAddress, + inputImageList, + ); // [1, 25200, 16] + } } catch (e, s) { - _logger.severe('Error while running inference: $e \n $s'); + dev.log('Error while running inference', error: e, stackTrace: s); throw YOLOFaceInterpreterRunException(); } stopwatchInterpreter.stop(); - _logger.info( - 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + try { + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = + _yoloPostProcessOutputs(nestedResults!, newSize); + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return (relativeDetections, originalSize); + } catch (e, s) { + _logger.severe('Error while post processing', e, s); + rethrow; + } + } + + static List>>? _runJNIBasedPlugin( + int sessionAddress, + Float32List inputImageList, + ) { + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, ); + final inputs = {'input': inputOrt}; - final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + final List? outputs = session.run(runOptions, inputs); + // inputOrt.release(); + // runOptions.release(); + return outputs?[0]?.value as List>>; // [1, 25200, 16] + } - stopwatch.stop(); - _logger.info( - 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + static Future>>> _runCustomPlugin( + Float32List inputImageList, + ) async { + final OnnxDart plugin = OnnxDart(); + final result = await plugin.predict( + inputImageList, + "YOLO_FACE", ); - - return (relativeDetections, originalSize); + final List>> reconstructedTensor = []; + for (int i = 0; i < result.length; i += 25200 * 16) { + final List> outerArray = []; + for (int j = 0; j < 25200; j++) { + final List innerArray = + result.sublist(i + j * 16, i + (j + 1) * 16).cast(); + outerArray.add(innerArray); + } + reconstructedTensor.add(outerArray); + } + return reconstructedTensor; } static List _yoloPostProcessOutputs( - List? outputs, + List>> nestedResults, Dimensions newSize, ) { - // // Get output tensors - final nestedResults = - outputs?[0]?.value as List>>; // [1, 25200, 16] final firstResults = nestedResults[0]; // [25200, 16] // Filter output @@ -160,11 +208,6 @@ class FaceDetectionService { results: firstResults, ); - // Release outputs - // outputs?.forEach((element) { - // element?.release(); - // }); - // Account for the fact that the aspect ratio was maintained for (final faceDetection in relativeDetections) { faceDetection.correctForMaintainedAspectRatio( diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 99172153b5..1bfb13b622 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -744,6 +744,7 @@ class FaceMlService { final sendPort = message[2] as SendPort; try { + await FaceDetectionService.instance.init(); switch (function) { case FaceMlOperation.analyzeImage: final time = DateTime.now(); @@ -838,7 +839,6 @@ class FaceMlService { final String? filePath = await _getImagePathForML(enteFile, typeOfData: FileDataForML.fileData); - if (filePath == null) { _logger.warning( "Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID} since its file path is null", diff --git a/mobile/plugins/onnx_dart/.metadata b/mobile/plugins/onnx_dart/.metadata new file mode 100644 index 0000000000..cf57c0cccc --- /dev/null +++ b/mobile/plugins/onnx_dart/.metadata @@ -0,0 +1,30 @@ +# This file tracks properties of this Flutter project. +# Used by Flutter tool to assess capabilities and perform upgrades etc. +# +# This file should be version controlled and should not be manually edited. + +version: + revision: "761747bfc538b5af34aa0d3fac380f1bc331ec49" + channel: "[user-branch]" + +project_type: plugin + +# Tracks metadata for the flutter migrate command +migration: + platforms: + - platform: root + create_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49 + base_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49 + - platform: android + create_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49 + base_revision: 761747bfc538b5af34aa0d3fac380f1bc331ec49 + + # User provided section + + # List of Local paths (relative to this file) that should be + # ignored by the migrate tool. + # + # Files that are not part of the templates will be ignored by default. + unmanaged_files: + - 'lib/main.dart' + - 'ios/Runner.xcodeproj/project.pbxproj' diff --git a/mobile/plugins/onnx_dart/README.md b/mobile/plugins/onnx_dart/README.md new file mode 100644 index 0000000000..c4b8097695 --- /dev/null +++ b/mobile/plugins/onnx_dart/README.md @@ -0,0 +1,15 @@ +# onnx_dart + +A new Flutter plugin project. + +## Getting Started + +This project is a starting point for a Flutter +[plug-in package](https://flutter.dev/developing-packages/), +a specialized package that includes platform-specific implementation code for +Android and/or iOS. + +For help getting started with Flutter development, view the +[online documentation](https://flutter.dev/docs), which offers tutorials, +samples, guidance on mobile development, and a full API reference. + diff --git a/mobile/plugins/onnx_dart/analysis_options.yaml b/mobile/plugins/onnx_dart/analysis_options.yaml new file mode 100644 index 0000000000..fac60e247c --- /dev/null +++ b/mobile/plugins/onnx_dart/analysis_options.yaml @@ -0,0 +1 @@ +include: ../../analysis_options.yaml \ No newline at end of file diff --git a/mobile/plugins/onnx_dart/android/build.gradle b/mobile/plugins/onnx_dart/android/build.gradle new file mode 100644 index 0000000000..61a05bb66a --- /dev/null +++ b/mobile/plugins/onnx_dart/android/build.gradle @@ -0,0 +1,70 @@ +group = "io.ente.photos.onnx_dart" +version = "1.0-SNAPSHOT" + +buildscript { + ext.kotlin_version = "1.7.10" + repositories { + google() + mavenCentral() + } + + dependencies { + classpath("com.android.tools.build:gradle:7.3.0") + classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version") + } +} + +allprojects { + repositories { + google() + mavenCentral() + } +} + +apply plugin: "com.android.library" +apply plugin: "kotlin-android" + +android { + if (project.android.hasProperty("namespace")) { + namespace = "io.ente.photos.onnx_dart" + } + + compileSdk = 34 + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + + kotlinOptions { + jvmTarget = "1.8" + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + test.java.srcDirs += "src/test/kotlin" + } + + defaultConfig { + minSdk = 21 + } + + dependencies { + testImplementation("org.jetbrains.kotlin:kotlin-test") + testImplementation("org.mockito:mockito-core:5.0.0") + + implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release' + } + + testOptions { + unitTests.all { + useJUnitPlatform() + + testLogging { + events "passed", "skipped", "failed", "standardOut", "standardError" + outputs.upToDateWhen {false} + showStandardStreams = true + } + } + } +} diff --git a/mobile/plugins/onnx_dart/android/local.properties b/mobile/plugins/onnx_dart/android/local.properties new file mode 100644 index 0000000000..54c35dc361 --- /dev/null +++ b/mobile/plugins/onnx_dart/android/local.properties @@ -0,0 +1,2 @@ +sdk.dir=/Users/ua741/Library/Android/sdk +flutter.sdk=/Users/ua741/work/flutter \ No newline at end of file diff --git a/mobile/plugins/onnx_dart/android/settings.gradle b/mobile/plugins/onnx_dart/android/settings.gradle new file mode 100644 index 0000000000..d207db8c75 --- /dev/null +++ b/mobile/plugins/onnx_dart/android/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'onnx_dart' diff --git a/mobile/plugins/onnx_dart/android/src/main/AndroidManifest.xml b/mobile/plugins/onnx_dart/android/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..7582c8c286 --- /dev/null +++ b/mobile/plugins/onnx_dart/android/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt new file mode 100644 index 0000000000..4091f7106c --- /dev/null +++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt @@ -0,0 +1,253 @@ +package io.ente.photos.onnx_dart + +import android.content.Context +import androidx.annotation.NonNull +import ai.onnxruntime.* +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import io.flutter.plugin.common.MethodChannel.MethodCallHandler +import io.flutter.plugin.common.MethodChannel.Result +import kotlinx.coroutines.* +import java.nio.FloatBuffer +import java.util.concurrent.ConcurrentHashMap +import android.util.Log +import java.io.File +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.concurrent.ConcurrentLinkedQueue + +object LongArrayPool { + private val poolMap = ConcurrentHashMap>() + + fun get(size: Int): LongArray { + val pool = getPool(size) + return pool.poll() ?: LongArray(size) + } + + fun release(array: LongArray) { + val pool = getPool(array.size) + pool.offer(array) + } + + private fun getPool(size: Int): ConcurrentLinkedQueue { + var pool = poolMap[size] + if (pool == null) { + synchronized(poolMap) { + pool = poolMap[size] + if (pool == null) { + pool = ConcurrentLinkedQueue() + poolMap[size] = pool!! + } + } + } + return pool!! + } +} +/** OnnxDartPlugin */ +class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { + /// The MethodChannel that will the communication between Flutter and native Android + /// + /// This local reference serves to register the plugin with the Flutter Engine and unregister it + /// when the Flutter Engine is detached from the Activity + private lateinit var channel : MethodChannel + + private var faceOrtEnv: OrtEnvironment = OrtEnvironment.getEnvironment() + private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + private val sessionMap = ConcurrentHashMap() + private lateinit var context: Context + + companion object { + const val DEFAULT_SESSION_COUNT = 1 + const val K_INPUT_WIDTH = 640 + const val K_INPUT_HEIGHT = 640 + const val K_NUM_CHANNELS = 3 + } + + enum class ModelType { + CLIP_TEXT, CLIP_VISUAL, YOLO_FACE, MOBILENET_FACE + } + + data class ModelState( + var isInitialized: Boolean = false, + val sessionAddresses: ConcurrentHashMap = ConcurrentHashMap(), + // number of sessions that should have been created for given model + var sessionsCount: Int = DEFAULT_SESSION_COUNT + ) + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, "onnx_dart") + channel.setMethodCallHandler(this) + context = flutterPluginBinding.applicationContext + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } + + + override fun onMethodCall(call: MethodCall, result: Result) { + if (call.method == "getPlatformVersion") { + result.success("Android ${android.os.Build.VERSION.RELEASE}") + } else if (call.method == "init") { + val modelType = call.argument("modelType") ?: run { + result.error("INVALID_ARGUMENT", "Model type is missing", null) + return + } + val modelPath = call.argument("modelPath") ?: run { + result.error("INVALID_ARGUMENT", "Model path is missing", null) + return + } + val sessionsCount = call.argument("sessionsCount") ?: DEFAULT_SESSION_COUNT + init(ModelType.valueOf(modelType), modelPath, sessionsCount, result) + } + else if (call.method == "release" ) { + val modelType = call.argument("modelType") ?: run { + result.error("INVALID_ARGUMENT", "Model type is missing", null) + return + } + release(ModelType.valueOf(modelType), result) + } + else if (call.method == "predict" ) { + val sessionAddress = call.argument("sessionAddress") + val modelType = call.argument("modelType") ?: run { + result.error("INVALID_ARGUMENT", "Model type is missing", null) + return + } + val inputDataArray = call.argument("inputData") + + if (sessionAddress == null || inputDataArray == null) { + result.error("INVALID_ARGUMENT", "Session address or input data is missing", null) + return + } +// val inputDataArray = inputData.map { it.toFloat() }.toFloatArray() + predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, result) + } else { + result.notImplemented() + } + } + + + private fun readModelFile(modelPath: String): ByteArray { + return File(modelPath).readBytes() + } + + private fun init(modelType: ModelType, modelPath: String, sessionsCount: Int, result: Result) { + Log.d("OnnxFlutterPlugin", " v: $modelType, path: $modelPath, sessionsCount: $sessionsCount") + scope.launch { + val modelState: ModelState + if (sessionMap.containsKey(modelType)) { + modelState = sessionMap[modelType]!! + } else { + modelState = ModelState() + sessionMap[modelType] = modelState + } + if (!modelState.isInitialized) { + for (i in 0 until sessionsCount) { + val session = createSession(faceOrtEnv, modelPath) + if (session != null) { + modelState.sessionAddresses[i] = session + } + } + modelState.isInitialized = true + modelState.sessionsCount = sessionsCount + withContext(Dispatchers.Main) { + Log.d("OnnxFlutterPlugin", "Model initialized: $modelType") + result.success(true) + } + } else { + withContext(Dispatchers.Main) { + result.success(false) + } + } + } + } + + private fun release(modelType: ModelType, result: Result) { + scope.launch { + val modelState = sessionMap[modelType] + modelState?.let { + it.sessionAddresses.forEach { entry: Map.Entry -> + entry.value.close() + } + it.sessionAddresses.clear() + it.isInitialized = false + } + withContext(Dispatchers.Main) { + result.success(true) + } + } + } + + private fun predict(modelType: ModelType, sessionAddress: Int, inputData: FloatArray, result: Result) { + scope.launch { + val modelState = sessionMap[modelType] + val session = modelState?.sessionAddresses?.get(sessionAddress) + if (session == null) { + withContext(Dispatchers.Main) { + result.error("SESSION_NOT_FOUND", "Session not found for address: $sessionAddress", null) + } + return@launch + } + + try { + val env = OrtEnvironment.getEnvironment() + val inputTensorShape = LongArrayPool.get(4).apply { + this[0] = 1 + this[1] = K_NUM_CHANNELS.toLong() + this[2] = K_INPUT_HEIGHT.toLong() + this[3] = K_INPUT_WIDTH.toLong() + } + val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape) + val inputs = mapOf("input" to inputTensor) + val outputs = session.run(inputs) + Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}") + + + inputTensor.close() + val totalSize = 1 * 25200 * 16 + val flatArray = FloatArray(totalSize) { index -> index + 1.0f } + + val outputTensor2 = Array(1) { outerIndex -> + Array(25200) { innerIndex -> + val startIndex = (outerIndex * 25200 + innerIndex) * 16 + flatArray.sliceArray(startIndex until startIndex + 16) + } + } + val outputTensor = outputs[0].value as Array> + Log.d("OnnxFlutterPlugin", "Output2 shape: ${outputTensor.size}") + + outputs.close() + LongArrayPool.release(inputTensorShape) + // Send the result back to the Dart layer + val flatList = outputTensor.flatten().flatMap { it.toList() } + val flatList2 = outputTensor2.flatten().flatMap { it.toList() } + + + withContext(Dispatchers.Main) { + result.success(flatList) +// result.success(flatList2) + + } + } catch (e: OrtException) { + withContext(Dispatchers.Main) { + result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null) + } + } + } + } + + private fun createSession(env: OrtEnvironment, modalPath: String): OrtSession? { + return env.createSession(modalPath, OrtSession.SessionOptions()) + } + + private fun releaseAllSessions() { + sessionMap.forEach { (_, modelState) -> + modelState.sessionAddresses.forEach { entry -> + entry.value.close() + } + modelState.sessionAddresses.clear() + } + sessionMap.clear() + } +} diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart.dart b/mobile/plugins/onnx_dart/lib/onnx_dart.dart new file mode 100644 index 0000000000..5beaee034c --- /dev/null +++ b/mobile/plugins/onnx_dart/lib/onnx_dart.dart @@ -0,0 +1,27 @@ +import 'dart:typed_data'; + +import 'package:onnx_dart/onnx_dart_platform_interface.dart'; + +class OnnxDart { + Future getPlatformVersion() { + return OnnxDartPlatform.instance.getPlatformVersion(); + } + + Future init( + String modelType, + String modelPath, { + int sessionsCount = 1, + }) { + return OnnxDartPlatform.instance + .init(modelType, modelPath, sessionsCount: sessionsCount); + } + + Future predict( + Float32List inputData, + String modelType, { + int sessionAddress = 0, + }) async { + return OnnxDartPlatform.instance + .predict(inputData, modelType, sessionAddress: sessionAddress); + } +} diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart new file mode 100644 index 0000000000..4a9d301fd5 --- /dev/null +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_method_channel.dart @@ -0,0 +1,73 @@ +import 'package:flutter/foundation.dart'; +import 'package:flutter/services.dart'; + +import 'onnx_dart_platform_interface.dart'; + +/// An implementation of [OnnxDartPlatform] that uses method channels. +class MethodChannelOnnxDart extends OnnxDartPlatform { + /// The method channel used to interact with the native platform. + @visibleForTesting + final methodChannel = const MethodChannel('onnx_dart'); + + @override + Future getPlatformVersion() async { + final version = + await methodChannel.invokeMethod('getPlatformVersion'); + return version; + } + + @override + Future init( + String modelType, + String modelPath, { + int sessionsCount = 1, + }) async { + final result = await methodChannel.invokeMethod('init', { + 'modelType': modelType, + 'modelPath': modelPath, + 'sessionsCount': sessionsCount, + }); + return result; + } + + @override + Future release(String modelType) async { + final bool? result = await methodChannel + .invokeMethod('release', {'modelType': modelType}); + return result; + } + + // @override + // Future?> predict( + // List inputData, + // String modelType, { + // int sessionAddress = 0, + // }) async { + // final List? result = + // await methodChannel.invokeMethod?>( + // 'predict', + // { + // 'sessionAddress': sessionAddress, + // 'inputData': inputData, + // 'modelType': modelType, + // }, + // ); + // return result!.cast(); + // } + + @override + Future predict( + Float32List inputData, + String modelType, { + int sessionAddress = 0, + }) { + return methodChannel.invokeMethod( + 'predict', + { + 'sessionAddress': sessionAddress, + 'inputData': inputData, + 'modelType': modelType, + }, + ); + } +} diff --git a/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart new file mode 100644 index 0000000000..40f80fe523 --- /dev/null +++ b/mobile/plugins/onnx_dart/lib/onnx_dart_platform_interface.dart @@ -0,0 +1,51 @@ +import 'dart:typed_data'; + +import 'package:plugin_platform_interface/plugin_platform_interface.dart'; + +import 'onnx_dart_method_channel.dart'; + +abstract class OnnxDartPlatform extends PlatformInterface { + /// Constructs a OnnxDartPlatform. + OnnxDartPlatform() : super(token: _token); + + static final Object _token = Object(); + + static OnnxDartPlatform _instance = MethodChannelOnnxDart(); + + /// The default instance of [OnnxDartPlatform] to use. + /// + /// Defaults to [MethodChannelOnnxDart]. + static OnnxDartPlatform get instance => _instance; + + /// Platform-specific implementations should set this with their own + /// platform-specific class that extends [OnnxDartPlatform] when + /// they register themselves. + static set instance(OnnxDartPlatform instance) { + PlatformInterface.verifyToken(instance, _token); + _instance = instance; + } + + Future getPlatformVersion() { + throw UnimplementedError('platformVersion() has not been implemented.'); + } + + Future init( + String modelType, + String modelPath, { + int sessionsCount = 1, + }) { + throw UnimplementedError('init() has not been implemented.'); + } + + Future release(String modelType) { + throw UnimplementedError('release() has not been implemented.'); + } + + Future predict( + Float32List inputData, + String modelType, { + int sessionAddress = 0, + }) { + throw UnimplementedError('predict() has not been implemented.'); + } +} diff --git a/mobile/plugins/onnx_dart/pubspec.lock b/mobile/plugins/onnx_dart/pubspec.lock new file mode 100644 index 0000000000..c14d3ed876 --- /dev/null +++ b/mobile/plugins/onnx_dart/pubspec.lock @@ -0,0 +1,213 @@ +# Generated by pub +# See https://dart.dev/tools/pub/glossary#lockfile +packages: + async: + dependency: transitive + description: + name: async + sha256: "947bfcf187f74dbc5e146c9eb9c0f10c9f8b30743e341481c1e2ed3ecc18c20c" + url: "https://pub.dev" + source: hosted + version: "2.11.0" + boolean_selector: + dependency: transitive + description: + name: boolean_selector + sha256: "6cfb5af12253eaf2b368f07bacc5a80d1301a071c73360d746b7f2e32d762c66" + url: "https://pub.dev" + source: hosted + version: "2.1.1" + characters: + dependency: transitive + description: + name: characters + sha256: "04a925763edad70e8443c99234dc3328f442e811f1d8fd1a72f1c8ad0f69a605" + url: "https://pub.dev" + source: hosted + version: "1.3.0" + clock: + dependency: transitive + description: + name: clock + sha256: cb6d7f03e1de671e34607e909a7213e31d7752be4fb66a86d29fe1eb14bfb5cf + url: "https://pub.dev" + source: hosted + version: "1.1.1" + collection: + dependency: transitive + description: + name: collection + sha256: ee67cb0715911d28db6bf4af1026078bd6f0128b07a5f66fb2ed94ec6783c09a + url: "https://pub.dev" + source: hosted + version: "1.18.0" + fake_async: + dependency: transitive + description: + name: fake_async + sha256: "511392330127add0b769b75a987850d136345d9227c6b94c96a04cf4a391bf78" + url: "https://pub.dev" + source: hosted + version: "1.3.1" + flutter: + dependency: "direct main" + description: flutter + source: sdk + version: "0.0.0" + flutter_lints: + dependency: "direct dev" + description: + name: flutter_lints + sha256: "9e8c3858111da373efc5aa341de011d9bd23e2c5c5e0c62bccf32438e192d7b1" + url: "https://pub.dev" + source: hosted + version: "3.0.2" + flutter_test: + dependency: "direct dev" + description: flutter + source: sdk + version: "0.0.0" + leak_tracker: + dependency: transitive + description: + name: leak_tracker + sha256: "7f0df31977cb2c0b88585095d168e689669a2cc9b97c309665e3386f3e9d341a" + url: "https://pub.dev" + source: hosted + version: "10.0.4" + leak_tracker_flutter_testing: + dependency: transitive + description: + name: leak_tracker_flutter_testing + sha256: "06e98f569d004c1315b991ded39924b21af84cf14cc94791b8aea337d25b57f8" + url: "https://pub.dev" + source: hosted + version: "3.0.3" + leak_tracker_testing: + dependency: transitive + description: + name: leak_tracker_testing + sha256: "6ba465d5d76e67ddf503e1161d1f4a6bc42306f9d66ca1e8f079a47290fb06d3" + url: "https://pub.dev" + source: hosted + version: "3.0.1" + lints: + dependency: transitive + description: + name: lints + sha256: cbf8d4b858bb0134ef3ef87841abdf8d63bfc255c266b7bf6b39daa1085c4290 + url: "https://pub.dev" + source: hosted + version: "3.0.0" + matcher: + dependency: transitive + description: + name: matcher + sha256: d2323aa2060500f906aa31a895b4030b6da3ebdcc5619d14ce1aada65cd161cb + url: "https://pub.dev" + source: hosted + version: "0.12.16+1" + material_color_utilities: + dependency: transitive + description: + name: material_color_utilities + sha256: "0e0a020085b65b6083975e499759762399b4475f766c21668c4ecca34ea74e5a" + url: "https://pub.dev" + source: hosted + version: "0.8.0" + meta: + dependency: transitive + description: + name: meta + sha256: "7687075e408b093f36e6bbf6c91878cc0d4cd10f409506f7bc996f68220b9136" + url: "https://pub.dev" + source: hosted + version: "1.12.0" + path: + dependency: transitive + description: + name: path + sha256: "087ce49c3f0dc39180befefc60fdb4acd8f8620e5682fe2476afd0b3688bb4af" + url: "https://pub.dev" + source: hosted + version: "1.9.0" + plugin_platform_interface: + dependency: "direct main" + description: + name: plugin_platform_interface + sha256: "4820fbfdb9478b1ebae27888254d445073732dae3d6ea81f0b7e06d5dedc3f02" + url: "https://pub.dev" + source: hosted + version: "2.1.8" + sky_engine: + dependency: transitive + description: flutter + source: sdk + version: "0.0.99" + source_span: + dependency: transitive + description: + name: source_span + sha256: "53e943d4206a5e30df338fd4c6e7a077e02254531b138a15aec3bd143c1a8b3c" + url: "https://pub.dev" + source: hosted + version: "1.10.0" + stack_trace: + dependency: transitive + description: + name: stack_trace + sha256: "73713990125a6d93122541237550ee3352a2d84baad52d375a4cad2eb9b7ce0b" + url: "https://pub.dev" + source: hosted + version: "1.11.1" + stream_channel: + dependency: transitive + description: + name: stream_channel + sha256: ba2aa5d8cc609d96bbb2899c28934f9e1af5cddbd60a827822ea467161eb54e7 + url: "https://pub.dev" + source: hosted + version: "2.1.2" + string_scanner: + dependency: transitive + description: + name: string_scanner + sha256: "556692adab6cfa87322a115640c11f13cb77b3f076ddcc5d6ae3c20242bedcde" + url: "https://pub.dev" + source: hosted + version: "1.2.0" + term_glyph: + dependency: transitive + description: + name: term_glyph + sha256: a29248a84fbb7c79282b40b8c72a1209db169a2e0542bce341da992fe1bc7e84 + url: "https://pub.dev" + source: hosted + version: "1.2.1" + test_api: + dependency: transitive + description: + name: test_api + sha256: "9955ae474176f7ac8ee4e989dadfb411a58c30415bcfb648fa04b2b8a03afa7f" + url: "https://pub.dev" + source: hosted + version: "0.7.0" + vector_math: + dependency: transitive + description: + name: vector_math + sha256: "80b3257d1492ce4d091729e3a67a60407d227c27241d6927be0130c98e741803" + url: "https://pub.dev" + source: hosted + version: "2.1.4" + vm_service: + dependency: transitive + description: + name: vm_service + sha256: "3923c89304b715fb1eb6423f017651664a03bf5f4b29983627c4da791f74a4ec" + url: "https://pub.dev" + source: hosted + version: "14.2.1" +sdks: + dart: ">=3.4.3 <4.0.0" + flutter: ">=3.18.0-18.0.pre.54" diff --git a/mobile/plugins/onnx_dart/pubspec.yaml b/mobile/plugins/onnx_dart/pubspec.yaml new file mode 100644 index 0000000000..5c08c872a4 --- /dev/null +++ b/mobile/plugins/onnx_dart/pubspec.yaml @@ -0,0 +1,70 @@ +name: onnx_dart +description: "A new Flutter plugin project." +version: 0.0.1 +homepage: + +environment: + sdk: '>=3.4.3 <4.0.0' + flutter: '>=3.3.0' + +dependencies: + flutter: + sdk: flutter + plugin_platform_interface: ^2.0.2 + +dev_dependencies: + flutter_test: + sdk: flutter + flutter_lints: ^3.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter packages. +flutter: + # This section identifies this Flutter project as a plugin project. + # The 'pluginClass' specifies the class (in Java, Kotlin, Swift, Objective-C, etc.) + # which should be registered in the plugin registry. This is required for + # using method channels. + # The Android 'package' specifies package in which the registered class is. + # This is required for using method channels on Android. + # The 'ffiPlugin' specifies that native code should be built and bundled. + # This is required for using `dart:ffi`. + # All these are used by the tooling to maintain consistency when + # adding or updating assets for this project. + plugin: + platforms: + android: + package: io.ente.photos.onnx_dart + pluginClass: OnnxDartPlugin + + # To add assets to your plugin package, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + # + # For details regarding assets in packages, see + # https://flutter.dev/assets-and-images/#from-packages + # + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware + + # To add custom fonts to your plugin package, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts in packages, see + # https://flutter.dev/custom-fonts/#from-packages diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock index bd536e27a7..db75a4328b 100644 --- a/mobile/pubspec.lock +++ b/mobile/pubspec.lock @@ -1606,6 +1606,13 @@ packages: url: "https://pub.dev" source: hosted version: "2.0.0" + onnx_dart: + dependency: "direct main" + description: + path: "plugins/onnx_dart" + relative: true + source: path + version: "0.0.1" onnxruntime: dependency: "direct main" description: @@ -2775,5 +2782,5 @@ packages: source: hosted version: "3.1.2" sdks: - dart: ">=3.4.0 <4.0.0" + dart: ">=3.4.3 <4.0.0" flutter: ">=3.22.0" diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index 58116972ce..2e9a71a4a2 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -128,6 +128,8 @@ dependencies: git: url: https://github.com/ente-io/onnxruntime.git ref: ente_onnxruntime + onnx_dart: + path: plugins/onnx_dart open_mail_app: ^0.4.5 package_info_plus: ^4.1.0 page_transition: ^2.0.2