[mob][photos] new clip works

laurenspriem
2024-09-25 22:58:23 +05:30
parent 53c19bc64e
commit 7dd9d8aef3
5 changed files with 78 additions and 60 deletions

View File

@@ -1,12 +1,10 @@
import "dart:typed_data" show Uint8List;
import "dart:typed_data" show Int32List, Uint8List;
import "dart:ui" show Image;
import "package:logging/logging.dart";
// import "package:onnx_dart/onnx_dart.dart";
import "package:onnx_dart/onnx_dart.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import "package:photos/utils/debug_ml_export_data.dart";
// import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/ml_util.dart";
class ClipImageEncoder extends MlModel {
@@ -36,27 +34,13 @@ class ClipImageEncoder extends MlModel {
int? enteFileID,
]) async {
final startTime = DateTime.now();
-// final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true);
-// final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false);
-// await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip");
-// await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip");
-final preprocessingTime = DateTime.now();
-final preprocessingMs =
-preprocessingTime.difference(startTime).inMilliseconds;
-late List<double> resultAa; //, resultNoaa;
+final inputShape = <int>[image.height, image.width, 4]; // [H, W, C]
+late List<double> result;
try {
-if (false) {
-// resultAa = await _runPlatformPluginPredict(rawRgbaBytes);
-// resultNoaa = await _runPlatformPluginPredict(inputListNoaa);
-// await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip");
-// await encodeAndSaveData(
-// resultNoaa,
-// "star-noaa-mobile-embedding",
-// "clip",
-// );
+if (MlModel.usePlatformPlugin) {
+result = await _runPlatformPluginPredict(rawRgbaBytes, inputShape);
} else {
-resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress);
-print('clip inference done with FFI package');
+result = _runFFIBasedPredict(rawRgbaBytes, inputShape, sessionAddress);
}
} catch (e, stackTrace) {
_logger.severe(
@@ -66,27 +50,21 @@ class ClipImageEncoder extends MlModel {
);
rethrow;
}
-final inferTime = DateTime.now();
-final inferenceMs = inferTime.difference(preprocessingTime).inMilliseconds;
-final totalMs = inferTime.difference(startTime).inMilliseconds;
+final totalMs = DateTime.now().difference(startTime).inMilliseconds;
_logger.info(
"Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
"Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""}",
);
-await encodeAndSaveData(
-resultAa,
-'singapore-rgba-mobile-embedding-ffi',
-'clip',
-);
-return resultAa;
+return result;
}
static List<double> _runFFIBasedPredict(
Uint8List inputImageList,
+List<int> inputImageShape,
int sessionAddress,
) {
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
-[1200, 1920, 4],
+inputImageShape,
);
final inputs = {'input': inputOrt};
final session = OrtSession.fromAddress(sessionAddress);
@@ -102,16 +80,18 @@ class ClipImageEncoder extends MlModel {
return embedding;
}
-// static Future<List<double>> _runPlatformPluginPredict(
-// Uint8List inputImageList,
-// ) async {
-// final OnnxDart plugin = OnnxDart();
-// final result = await plugin.predictRgba(
-// inputImageList,
-// _modelName,
-// );
-// final List<double> embedding = result!.sublist(0, 512);
-// normalizeEmbedding(embedding);
-// return embedding;
-// }
+static Future<List<double>> _runPlatformPluginPredict(
+Uint8List inputImageList,
+List<int> inputImageShape,
+) async {
+final OnnxDart plugin = OnnxDart();
+final result = await plugin.predictRgba(
+inputImageList,
+Int32List.fromList(inputImageShape),
+_modelName,
+);
+final List<double> embedding = result!.sublist(0, 512);
+normalizeEmbedding(embedding);
+return embedding;
+}
}
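The reworked predict() now takes the raw RGBA bytes together with an explicit [H, W, 4] shape instead of a preprocessed, fixed-size tensor. Below is a minimal sketch (not part of the commit) of one plausible way to produce that buffer/shape pair from a dart:ui Image; the helper name is hypothetical, and how the app actually obtains rawRgbaBytes is outside this diff.

import "dart:typed_data" show Uint8List;
import "dart:ui" show Image, ImageByteFormat;

/// Hypothetical helper: derive the raw RGBA buffer and the matching
/// [H, W, 4] shape that the reworked predict() forwards to the model.
Future<(Uint8List, List<int>)> rgbaWithShape(Image image) async {
  // rawRgba yields 4 bytes per pixel (R, G, B, A), row-major.
  final byteData = await image.toByteData(format: ImageByteFormat.rawRgba);
  final rawRgbaBytes = byteData!.buffer.asUint8List();
  final inputShape = <int>[image.height, image.width, 4]; // [H, W, C]
  return (rawRgbaBytes, inputShape);
}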

View File

@@ -93,13 +93,15 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
val modelType = call.argument<String>("modelType")
val inputDataArray = call.argument<FloatArray>("inputData")
val inputIntDataArray = call.argument<IntArray>("inputDataInt")
+val inputUint8DataArray = call.argument<ByteArray>("inputDataUint8")
+val inputShapeArray = call.argument<IntArray>("inputShapeList")
-if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) {
+if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null && inputUint8DataArray == null)) {
result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null)
return
}
-predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result)
+predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, inputUint8DataArray, inputShapeArray, result)
}
else -> {
result.notImplemented()
@@ -155,9 +157,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
}
}
-private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) {
+private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, inputUint8DataArray: ByteArray? = null, inputShapeArray: IntArray? = null, result: Result) {
// Assert that exactly one of inputDataFloat or inputDataInt is provided
-assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
+// assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
+// TODO: lau re-add above assert later
scope.launch {
val modelState = sessionMap[modelType]
@@ -178,8 +181,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
inputTensorShape = longArrayOf(totalSize, 112, 112, 3)
}
ModelType.ClipImageEncoder -> {
-inputTensorShape = longArrayOf(1, 3, 256, 256)
+// inputTensorShape = longArrayOf(1, 3, 256, 256)
+if (inputShapeArray != null) {
+inputTensorShape = inputShapeArray.map { it.toLong() }.toLongArray()
+} else {
+result.error("INVALID_ARGUMENT", "Input shape is missing for clip image input", null)
+}
}
ModelType.ClipTextEncoder -> {
inputTensorShape = longArrayOf(1, 77)
@@ -192,6 +198,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
val inputTensor = when {
inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape)
inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape)
+inputUint8DataArray != null -> OnnxTensor.createTensor(env, ByteBuffer.wrap(inputUint8DataArray), inputTensorShape, OnnxJavaType.UINT8)
else -> throw IllegalArgumentException("No input data provided")
}
val inputs = mutableMapOf<String, OnnxTensor>()
@@ -219,7 +226,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
inputTensor.close()
} catch (e: OrtException) {
withContext(Dispatchers.Main) {
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null)
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message} ${e.stackTraceToString()}", null)
}
} catch (e: Exception) {
Log.e(TAG, "Error during prediction: ${e.message}", e)
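On the Kotlin side the raw bytes are wrapped into a UINT8 OnnxTensor with whatever shape Dart supplies, so the byte count and the declared shape have to agree or ONNX Runtime will reject the tensor. A small, hypothetical Dart-side guard (not part of the commit) that a caller could run before crossing the method channel:

import "dart:typed_data" show Int32List, Uint8List;

/// Hypothetical pre-flight check for the RGBA path.
void checkRgbaShape(Uint8List bytes, Int32List shape) {
  final expected = shape.fold<int>(1, (a, b) => a * b);
  if (bytes.length != expected) {
    throw ArgumentError(
      "RGBA buffer has ${bytes.length} bytes but shape $shape implies $expected",
    );
  }
}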

View File

@@ -25,8 +25,13 @@ class OnnxDart {
String modelType, {
int sessionAddress = 0,
}) async {
-final result = await OnnxDartPlatform.instance
-.predict(inputData, null, modelType, sessionAddress: sessionAddress);
+final result = await OnnxDartPlatform.instance.predict(
+inputData,
+null,
+null,
+modelType,
+sessionAddress: sessionAddress,
+);
return result;
}
@@ -35,8 +40,30 @@ class OnnxDart {
String modelType, {
int sessionAddress = 0,
}) async {
-final result = await OnnxDartPlatform.instance
-.predict(null, inputDataInt, modelType, sessionAddress: sessionAddress);
+final result = await OnnxDartPlatform.instance.predict(
+null,
+inputDataInt,
+null,
+modelType,
+sessionAddress: sessionAddress,
+);
return result;
}
+Future<Float32List?> predictRgba(
+Uint8List inputBytes,
+Int32List inputShape,
+String modelType, {
+int sessionAddress = 0,
+}) async {
+final result = await OnnxDartPlatform.instance.predict(
+null,
+null,
+inputBytes,
+modelType,
+sessionAddress: sessionAddress,
+inputShapeList: inputShape,
+);
+return result;
+}
}
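For reference, a hedged usage sketch of the new predictRgba wrapper. The model type string has to name an entry of the Kotlin ModelType enum (the handler calls ModelType.valueOf), so "ClipImageEncoder" is used here; the session address and the RGBA buffer are assumed to come from the plugin's existing init path and image-decoding code.

import "dart:typed_data" show Int32List, Uint8List;

import "package:onnx_dart/onnx_dart.dart";

// Sketch only, not part of the commit.
Future<List<double>> clipEmbeddingViaPlugin(
  Uint8List rgbaBytes,
  int height,
  int width,
  int sessionAddress,
) async {
  final plugin = OnnxDart();
  final result = await plugin.predictRgba(
    rgbaBytes,
    Int32List.fromList([height, width, 4]),
    "ClipImageEncoder",
    sessionAddress: sessionAddress,
  );
  // The CLIP image output is a 512-d embedding; the real caller also runs
  // normalizeEmbedding on it (see ClipImageEncoder above).
  return result!.sublist(0, 512);
}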

View File

@@ -40,8 +40,10 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
Future<Float32List?> predict(
Float32List? inputData,
Int32List? inputDataInt,
+Uint8List? inputDataUint8,
String modelType, {
int sessionAddress = 0,
+Int32List? inputShapeList,
}) {
return methodChannel.invokeMethod<Float32List?>(
'predict',
@@ -49,7 +51,9 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
'sessionAddress': sessionAddress,
'inputData': inputData,
'inputDataInt': inputDataInt,
+'inputDataUint8': inputDataUint8,
'modelType': modelType,
"inputShapeList": inputShapeList,
},
);
}

View File

@@ -42,12 +42,12 @@ abstract class OnnxDartPlatform extends PlatformInterface {
Future<Float32List?> predict(
Float32List? inputData,
Int32List? inputDataInt,
+Uint8List? inputDataRgba,
String modelType, {
int sessionAddress = 0,
+Int32List? inputShapeList,
}) {
throw UnimplementedError('predict() has not been implemented.');
}
}
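Because the abstract predict() gained a third positional input and the inputShapeList named parameter, every implementation (and any test double) has to adopt the widened signature. A minimal hypothetical fake illustrating it; the import path is assumed from the plugin layout shown above.

import "dart:typed_data" show Float32List, Int32List, Uint8List;

// Import path assumed, not shown in this diff.
import "package:onnx_dart/onnx_dart_platform_interface.dart";

/// Hypothetical test double: records the shape it was given and returns a
/// zero embedding, mainly to spell out the full updated signature.
class FakeOnnxDartPlatform extends OnnxDartPlatform {
  Int32List? lastShape;

  @override
  Future<Float32List?> predict(
    Float32List? inputData,
    Int32List? inputDataInt,
    Uint8List? inputDataRgba,
    String modelType, {
    int sessionAddress = 0,
    Int32List? inputShapeList,
  }) async {
    lastShape = inputShapeList;
    return Float32List(512);
  }
}

In a test such a fake would typically be installed through the platform interface's instance setter before exercising the OnnxDart wrapper.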