diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart index f1dab0d935..46a748f27c 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart @@ -37,7 +37,9 @@ class FaceDetectionService extends MlModel { // Singleton pattern FaceDetectionService._privateConstructor(); + static final instance = FaceDetectionService._privateConstructor(); + factory FaceDetectionService() => instance; /// Detects faces in the given image data. @@ -141,17 +143,19 @@ class FaceDetectionService extends MlModel { inputImageList, _modelName, ); - final List>> reconstructedTensor = []; - for (int i = 0; i < result!.length; i += 25200 * 16) { - final List> outerArray = []; - for (int j = 0; j < 25200; j++) { - final List innerArray = - result.sublist(i + j * 16, i + (j + 1) * 16).cast(); - outerArray.add(innerArray); - } - reconstructedTensor.add(outerArray); - } - return reconstructedTensor; + + final int resultLength = result!.length; + assert(resultLength % 25200 * 16 == 0); + const int outerLength = 1; + const int middleLength = 25200; + const int innerLength = 16; + return List.generate( + outerLength, + (_) => List.generate( + middleLength, + (j) => result.sublist(j * innerLength, (j + 1) * innerLength).toList(), + ), + ); } static List _yoloPostProcessOutputs( diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt index c93da66a1e..61d7246050 100644 --- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt +++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt @@ -175,8 +175,8 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { val inputs = mapOf("input" to inputTensor) val outputs = session.run(inputs) Log.d(TAG, "Output shape: ${outputs.size()}") - val outputTensor = outputs[0].value as Array> - val flatList = outputTensor.flatMapToFloatArray() + val outputTensor = (outputs[0].value as Array>).get(0) + val flatList = outputTensor.flattenToFloatArray() withContext(Dispatchers.Main) { result.success(flatList) } @@ -217,4 +217,15 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { } return result } + fun Array.flattenToFloatArray(): FloatArray { + val outputSize = this.sumOf { it.size } + val result = FloatArray(outputSize) + var index = 0 + for (inner in this) { + for (value in inner) { + result[index++] = value + } + } + return result + } }