[mob][onnx] Clean up
This commit is contained in:
@@ -17,33 +17,6 @@ import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
|
||||
object LongArrayPool {
|
||||
private val poolMap = ConcurrentHashMap<Int, ConcurrentLinkedQueue<LongArray>>()
|
||||
|
||||
fun get(size: Int): LongArray {
|
||||
val pool = getPool(size)
|
||||
return pool.poll() ?: LongArray(size)
|
||||
}
|
||||
|
||||
fun release(array: LongArray) {
|
||||
val pool = getPool(array.size)
|
||||
pool.offer(array)
|
||||
}
|
||||
|
||||
private fun getPool(size: Int): ConcurrentLinkedQueue<LongArray> {
|
||||
var pool = poolMap[size]
|
||||
if (pool == null) {
|
||||
synchronized(poolMap) {
|
||||
pool = poolMap[size]
|
||||
if (pool == null) {
|
||||
pool = ConcurrentLinkedQueue()
|
||||
poolMap[size] = pool!!
|
||||
}
|
||||
}
|
||||
}
|
||||
return pool!!
|
||||
}
|
||||
}
|
||||
/** OnnxDartPlugin */
|
||||
class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
/// The MethodChannel that will the communication between Flutter and native Android
|
||||
@@ -192,18 +165,12 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
|
||||
try {
|
||||
val env = OrtEnvironment.getEnvironment()
|
||||
val inputTensorShape = LongArrayPool.get(4).apply {
|
||||
this[0] = 1
|
||||
this[1] = K_NUM_CHANNELS.toLong()
|
||||
this[2] = K_INPUT_HEIGHT.toLong()
|
||||
this[3] = K_INPUT_WIDTH.toLong()
|
||||
}
|
||||
val inputTensorShape = longArrayOf(1, K_NUM_CHANNELS.toLong(), K_INPUT_HEIGHT.toLong(), K_INPUT_WIDTH.toLong())
|
||||
val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
|
||||
val inputs = mapOf("input" to inputTensor)
|
||||
val outputs = session.run(inputs)
|
||||
Log.d("OnnxFlutterPlugin", "Output shape: ${outputs.size()}")
|
||||
|
||||
|
||||
inputTensor.close()
|
||||
val totalSize = 1 * 25200 * 16
|
||||
val flatArray = FloatArray(totalSize) { index -> index + 1.0f }
|
||||
@@ -218,7 +185,6 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
Log.d("OnnxFlutterPlugin", "Output2 shape: ${outputTensor.size}")
|
||||
|
||||
outputs.close()
|
||||
LongArrayPool.release(inputTensorShape)
|
||||
// Send the result back to the Dart layer
|
||||
val flatList = outputTensor.flatten().flatMap { it.toList() }
|
||||
val flatList2 = outputTensor2.flatten().flatMap { it.toList() }
|
||||
|
||||
Reference in New Issue
Block a user