[mob][onnx] Clean up

This commit is contained in:
Neeraj Gupta
2024-07-09 14:31:03 +05:30
parent bc04a3f87f
commit f422dea00e

View File

@@ -17,33 +17,6 @@ import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.concurrent.ConcurrentLinkedQueue
object LongArrayPool {
private val poolMap = ConcurrentHashMap<Int, ConcurrentLinkedQueue<LongArray>>()
fun get(size: Int): LongArray {
val pool = getPool(size)
return pool.poll() ?: LongArray(size)
}
fun release(array: LongArray) {
val pool = getPool(array.size)
pool.offer(array)
}
private fun getPool(size: Int): ConcurrentLinkedQueue<LongArray> {
var pool = poolMap[size]
if (pool == null) {
synchronized(poolMap) {
pool = poolMap[size]
if (pool == null) {
pool = ConcurrentLinkedQueue()
poolMap[size] = pool!!
}
}
}
return pool!!
}
}
/** OnnxDartPlugin */
class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
/// The MethodChannel that will the communication between Flutter and native Android
@@ -192,18 +165,12 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
try {
val env = OrtEnvironment.getEnvironment()
val inputTensorShape = LongArrayPool.get(4).apply {
this[0] = 1
this[1] = K_NUM_CHANNELS.toLong()
this[2] = K_INPUT_HEIGHT.toLong()
this[3] = K_INPUT_WIDTH.toLong()
}
val inputTensorShape = longArrayOf(1, K_NUM_CHANNELS.toLong(), K_INPUT_HEIGHT.toLong(), K_INPUT_WIDTH.toLong())
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
val inputs = mapOf("input" to inputTensor)
val outputs = session.run(inputs)
Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}")
inputTensor.close()
val totalSize = 1 * 25200 * 16
val flatArray = FloatArray(totalSize) { index -> index + 1.0f }
@@ -218,7 +185,6 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
Log.d("OnnxFlutterPlugin", "Output2 shape: ${outputTensor.size}")
outputs.close()
LongArrayPool.release(inputTensorShape)
// Send the result back to the Dart layer
val flatList = outputTensor.flatten().flatMap { it.toList() }
val flatList2 = outputTensor2.flatten().flatMap { it.toList() }