[mob] Refactor
This commit is contained in:
@@ -37,7 +37,9 @@ class FaceDetectionService extends MlModel {
|
||||
|
||||
// Singleton pattern
|
||||
FaceDetectionService._privateConstructor();
|
||||
|
||||
static final instance = FaceDetectionService._privateConstructor();
|
||||
|
||||
factory FaceDetectionService() => instance;
|
||||
|
||||
/// Detects faces in the given image data.
|
||||
@@ -141,17 +143,19 @@ class FaceDetectionService extends MlModel {
|
||||
inputImageList,
|
||||
_modelName,
|
||||
);
|
||||
final List<List<List<double>>> reconstructedTensor = [];
|
||||
for (int i = 0; i < result!.length; i += 25200 * 16) {
|
||||
final List<List<double>> outerArray = [];
|
||||
for (int j = 0; j < 25200; j++) {
|
||||
final List<double> innerArray =
|
||||
result.sublist(i + j * 16, i + (j + 1) * 16).cast<double>();
|
||||
outerArray.add(innerArray);
|
||||
}
|
||||
reconstructedTensor.add(outerArray);
|
||||
}
|
||||
return reconstructedTensor;
|
||||
|
||||
final int resultLength = result!.length;
|
||||
assert(resultLength % 25200 * 16 == 0);
|
||||
const int outerLength = 1;
|
||||
const int middleLength = 25200;
|
||||
const int innerLength = 16;
|
||||
return List.generate(
|
||||
outerLength,
|
||||
(_) => List.generate(
|
||||
middleLength,
|
||||
(j) => result.sublist(j * innerLength, (j + 1) * innerLength).toList(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
static List<FaceDetectionRelative> _yoloPostProcessOutputs(
|
||||
|
||||
@@ -175,8 +175,8 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
val inputs = mapOf("input" to inputTensor)
|
||||
val outputs = session.run(inputs)
|
||||
Log.d(TAG, "Output shape: ${outputs.size()}")
|
||||
val outputTensor = outputs[0].value as Array<Array<FloatArray>>
|
||||
val flatList = outputTensor.flatMapToFloatArray()
|
||||
val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
|
||||
val flatList = outputTensor.flattenToFloatArray()
|
||||
withContext(Dispatchers.Main) {
|
||||
result.success(flatList)
|
||||
}
|
||||
@@ -217,4 +217,15 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
}
|
||||
return result
|
||||
}
|
||||
fun Array<FloatArray>.flattenToFloatArray(): FloatArray {
|
||||
val outputSize = this.sumOf { it.size }
|
||||
val result = FloatArray(outputSize)
|
||||
var index = 0
|
||||
for (inner in this) {
|
||||
for (value in inner) {
|
||||
result[index++] = value
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user