[mob][photos] new clip works
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
import "dart:typed_data" show Uint8List;
|
||||
import "dart:typed_data" show Int32List, Uint8List;
|
||||
import "dart:ui" show Image;
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
// import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
import "package:photos/utils/debug_ml_export_data.dart";
|
||||
// import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
class ClipImageEncoder extends MlModel {
|
||||
@@ -36,27 +34,13 @@ class ClipImageEncoder extends MlModel {
|
||||
int? enteFileID,
|
||||
]) async {
|
||||
final startTime = DateTime.now();
|
||||
// final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true);
|
||||
// final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false);
|
||||
// await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip");
|
||||
// await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip");
|
||||
final preprocessingTime = DateTime.now();
|
||||
final preprocessingMs =
|
||||
preprocessingTime.difference(startTime).inMilliseconds;
|
||||
late List<double> resultAa; //, resultNoaa;
|
||||
final inputShape = <int>[image.height, image.width, 4]; // [H, W, C]
|
||||
late List<double> result;
|
||||
try {
|
||||
if (false) {
|
||||
// resultAa = await _runPlatformPluginPredict(rawRgbaBytes);
|
||||
// resultNoaa = await _runPlatformPluginPredict(inputListNoaa);
|
||||
// await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip");
|
||||
// await encodeAndSaveData(
|
||||
// resultNoaa,
|
||||
// "star-noaa-mobile-embedding",
|
||||
// "clip",
|
||||
// );
|
||||
if (MlModel.usePlatformPlugin) {
|
||||
result = await _runPlatformPluginPredict(rawRgbaBytes, inputShape);
|
||||
} else {
|
||||
resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress);
|
||||
print('clip inference done with FFI package');
|
||||
result = _runFFIBasedPredict(rawRgbaBytes, inputShape, sessionAddress);
|
||||
}
|
||||
} catch (e, stackTrace) {
|
||||
_logger.severe(
|
||||
@@ -66,27 +50,21 @@ class ClipImageEncoder extends MlModel {
|
||||
);
|
||||
rethrow;
|
||||
}
|
||||
final inferTime = DateTime.now();
|
||||
final inferenceMs = inferTime.difference(preprocessingTime).inMilliseconds;
|
||||
final totalMs = inferTime.difference(startTime).inMilliseconds;
|
||||
final totalMs = DateTime.now().difference(startTime).inMilliseconds;
|
||||
_logger.info(
|
||||
"Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
|
||||
"Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""}",
|
||||
);
|
||||
await encodeAndSaveData(
|
||||
resultAa,
|
||||
'singapore-rgba-mobile-embedding-ffi',
|
||||
'clip',
|
||||
);
|
||||
return resultAa;
|
||||
return result;
|
||||
}
|
||||
|
||||
static List<double> _runFFIBasedPredict(
|
||||
Uint8List inputImageList,
|
||||
List<int> inputImageShape,
|
||||
int sessionAddress,
|
||||
) {
|
||||
final inputOrt = OrtValueTensor.createTensorWithDataList(
|
||||
inputImageList,
|
||||
[1200, 1920, 4],
|
||||
inputImageShape,
|
||||
);
|
||||
final inputs = {'input': inputOrt};
|
||||
final session = OrtSession.fromAddress(sessionAddress);
|
||||
@@ -102,16 +80,18 @@ class ClipImageEncoder extends MlModel {
|
||||
return embedding;
|
||||
}
|
||||
|
||||
// static Future<List<double>> _runPlatformPluginPredict(
|
||||
// Uint8List inputImageList,
|
||||
// ) async {
|
||||
// final OnnxDart plugin = OnnxDart();
|
||||
// final result = await plugin.predictRgba(
|
||||
// inputImageList,
|
||||
// _modelName,
|
||||
// );
|
||||
// final List<double> embedding = result!.sublist(0, 512);
|
||||
// normalizeEmbedding(embedding);
|
||||
// return embedding;
|
||||
// }
|
||||
/// Runs CLIP image inference through the `OnnxDart` platform plugin.
///
/// [inputImageList] holds the raw image bytes and [inputImageShape] their
/// dimensions; both are forwarded to `OnnxDart.predictRgba` together with
/// `_modelName`. The visible caller passes `[image.height, image.width, 4]`
/// as the shape.
///
/// Returns the first 512 values of the plugin output after passing them
/// through `normalizeEmbedding`. Throws if the plugin returns `null`, because
/// the result is force-unwrapped with `!`.
static Future<List<double>> _runPlatformPluginPredict(
|
||||
Uint8List inputImageList,
|
||||
List<int> inputImageShape,
|
||||
) async {
|
||||
final OnnxDart plugin = OnnxDart();
|
||||
final result = await plugin.predictRgba(
|
||||
inputImageList,
|
||||
// predictRgba takes the shape as an Int32List, so convert the List<int>.
Int32List.fromList(inputImageShape),
|
||||
_modelName,
|
||||
);
|
||||
// Keep only the first 512 output values (presumably the CLIP embedding
// size -- TODO confirm against the model); `result!` throws on null.
final List<double> embedding = result!.sublist(0, 512);
|
||||
normalizeEmbedding(embedding);
|
||||
return embedding;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,13 +93,15 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
val modelType = call.argument<String>("modelType")
|
||||
val inputDataArray = call.argument<FloatArray>("inputData")
|
||||
val inputIntDataArray = call.argument<IntArray>("inputDataInt")
|
||||
val inputUint8DataArray = call.argument<ByteArray>("inputDataUint8")
|
||||
val inputShapeArray = call.argument<IntArray>("inputShapeList")
|
||||
|
||||
if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null)) {
|
||||
if (sessionAddress == null || modelType == null || (inputDataArray == null && inputIntDataArray == null && inputUint8DataArray == null)) {
|
||||
result.error("INVALID_ARGUMENT", "Session address, model type, or input data is missing", null)
|
||||
return
|
||||
}
|
||||
|
||||
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, result)
|
||||
predict(ModelType.valueOf(modelType), sessionAddress, inputDataArray, inputIntDataArray, inputUint8DataArray, inputShapeArray, result)
|
||||
}
|
||||
else -> {
|
||||
result.notImplemented()
|
||||
@@ -155,9 +157,10 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
}
|
||||
}
|
||||
|
||||
private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, result: Result) {
|
||||
private fun predict(modelType: ModelType, sessionAddress: Int, inputDataFloat: FloatArray? = null, inputDataInt: IntArray? = null, inputUint8DataArray: ByteArray? = null, inputShapeArray: IntArray? = null, result: Result) {
|
||||
// Assert that exactly one of inputDataFloat or inputDataInt is provided
|
||||
assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
|
||||
// assert((inputDataFloat != null).xor(inputDataInt != null)) { "Exactly one of inputDataFloat or inputDataInt must be provided" }
|
||||
// TODO: lau re-add above assert later
|
||||
|
||||
scope.launch {
|
||||
val modelState = sessionMap[modelType]
|
||||
@@ -178,8 +181,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
inputTensorShape = longArrayOf(totalSize, 112, 112, 3)
|
||||
}
|
||||
ModelType.ClipImageEncoder -> {
|
||||
inputTensorShape = longArrayOf(1, 3, 256, 256)
|
||||
// inputTensorShape = longArrayOf(1, 3, 256, 256)
|
||||
if (inputShapeArray != null) {
|
||||
inputTensorShape = inputShapeArray.map { it.toLong() }.toLongArray()
|
||||
} else {
|
||||
result.error("INVALID_ARGUMENT", "Input shape is missing for clip image input", null)
|
||||
}
|
||||
}
|
||||
ModelType.ClipTextEncoder -> {
|
||||
inputTensorShape = longArrayOf(1, 77)
|
||||
@@ -192,6 +198,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
val inputTensor = when {
|
||||
inputDataFloat != null -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputDataFloat), inputTensorShape)
|
||||
inputDataInt != null -> OnnxTensor.createTensor(env, IntBuffer.wrap(inputDataInt), inputTensorShape)
|
||||
inputUint8DataArray != null -> OnnxTensor.createTensor(env, ByteBuffer.wrap(inputUint8DataArray), inputTensorShape, OnnxJavaType.UINT8)
|
||||
else -> throw IllegalArgumentException("No input data provided")
|
||||
}
|
||||
val inputs = mutableMapOf<String, OnnxTensor>()
|
||||
@@ -219,7 +226,7 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
|
||||
inputTensor.close()
|
||||
} catch (e: OrtException) {
|
||||
withContext(Dispatchers.Main) {
|
||||
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message}", null)
|
||||
result.error("PREDICTION_ERROR", "Error during prediction: ${e.message} ${e.stackTraceToString()}", null)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error during prediction: ${e.message}", e)
|
||||
|
||||
@@ -25,8 +25,13 @@ class OnnxDart {
|
||||
String modelType, {
|
||||
int sessionAddress = 0,
|
||||
}) async {
|
||||
final result = await OnnxDartPlatform.instance
|
||||
.predict(inputData, null, modelType, sessionAddress: sessionAddress);
|
||||
final result = await OnnxDartPlatform.instance.predict(
|
||||
inputData,
|
||||
null,
|
||||
null,
|
||||
modelType,
|
||||
sessionAddress: sessionAddress,
|
||||
);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -35,8 +40,30 @@ class OnnxDart {
|
||||
String modelType, {
|
||||
int sessionAddress = 0,
|
||||
}) async {
|
||||
final result = await OnnxDartPlatform.instance
|
||||
.predict(null, inputDataInt, modelType, sessionAddress: sessionAddress);
|
||||
final result = await OnnxDartPlatform.instance.predict(
|
||||
null,
|
||||
inputDataInt,
|
||||
null,
|
||||
modelType,
|
||||
sessionAddress: sessionAddress,
|
||||
);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Runs a prediction on byte input by forwarding to
/// `OnnxDartPlatform.instance.predict`.
///
/// [inputBytes] is passed as the third positional (Uint8) input, with the
/// float and int inputs set to `null`. [inputShape] is forwarded as the
/// `inputShapeList` named argument, and [modelType] and [sessionAddress] are
/// passed through unchanged.
///
/// Returns whatever the platform implementation returns, which may be `null`.
Future<Float32List?> predictRgba(
|
||||
Uint8List inputBytes,
|
||||
Int32List inputShape,
|
||||
String modelType, {
|
||||
int sessionAddress = 0,
|
||||
}) async {
|
||||
final result = await OnnxDartPlatform.instance.predict(
|
||||
// No Float32List input for this call.
null,
|
||||
// No Int32List input for this call.
null,
|
||||
inputBytes,
|
||||
modelType,
|
||||
sessionAddress: sessionAddress,
|
||||
inputShapeList: inputShape,
|
||||
);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,8 +40,10 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
|
||||
Future<Float32List?> predict(
|
||||
Float32List? inputData,
|
||||
Int32List? inputDataInt,
|
||||
Uint8List? inputDataUint8,
|
||||
String modelType, {
|
||||
int sessionAddress = 0,
|
||||
Int32List? inputShapeList,
|
||||
}) {
|
||||
return methodChannel.invokeMethod<Float32List?>(
|
||||
'predict',
|
||||
@@ -49,7 +51,9 @@ class MethodChannelOnnxDart extends OnnxDartPlatform {
|
||||
'sessionAddress': sessionAddress,
|
||||
'inputData': inputData,
|
||||
'inputDataInt': inputDataInt,
|
||||
'inputDataUint8': inputDataUint8,
|
||||
'modelType': modelType,
|
||||
"inputShapeList": inputShapeList,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ abstract class OnnxDartPlatform extends PlatformInterface {
|
||||
|
||||
Future<Float32List?> predict(
|
||||
Float32List? inputData,
|
||||
Int32List? inputDataInt,
|
||||
Int32List? inputDataInt,
|
||||
Uint8List? inputDataRgba,
|
||||
String modelType, {
|
||||
int sessionAddress = 0,
|
||||
Int32List? inputShapeList,
|
||||
}) {
|
||||
throw UnimplementedError('predict() has not been implemented.');
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user