diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt index 61d7246050..d212b75ab3 100644 --- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt +++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt @@ -39,9 +39,6 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { } companion object { const val DEFAULT_SESSION_COUNT = 1 - const val K_INPUT_WIDTH = 640 - const val K_INPUT_HEIGHT = 640 - const val K_NUM_CHANNELS = 3 } @@ -170,7 +167,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler { try { val env = OrtEnvironment.getEnvironment() - val inputTensorShape = longArrayOf(1, K_NUM_CHANNELS.toLong(), K_INPUT_HEIGHT.toLong(), K_INPUT_WIDTH.toLong()) + val inputTensorShape = longArrayOf(1, 3, 640, 640) val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape) val inputs = mapOf("input" to inputTensor) val outputs = session.run(inputs)