[mob][photos] Test for iOS

Author: laurenspriem
Date:   2024-09-25 14:34:25 +05:30
Parent: 8970074f80
Commit: bd232c151a

@@ -1,15 +1,17 @@
-import "dart:typed_data" show Uint8List, Float32List;
+import "dart:typed_data" show Float32List, Int8List, Uint8List;
 import "dart:ui" show Image;
 import "package:logging/logging.dart";
 import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
-import "package:photos/utils/image_ml_util.dart";
+import "package:photos/utils/debug_ml_export_data.dart";
+// import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";
 class ClipImageEncoder extends MlModel {
-  static const kRemoteBucketModelPath = "mobileclip_s2_image.onnx";
+  static const kRemoteBucketModelPath =
+      "mobileclip_s2_image_opset18_rgba_sim.onnx";
   static const _modelName = "ClipImageEncoder";
   @override
@@ -34,16 +36,27 @@ class ClipImageEncoder extends MlModel {
     int? enteFileID,
   ]) async {
     final startTime = DateTime.now();
-    final inputList = await preprocessImageClip(image, rawRgbaBytes);
+    // final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true);
+    // final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false);
+    // await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip");
+    // await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip");
     final preprocessingTime = DateTime.now();
     final preprocessingMs =
         preprocessingTime.difference(startTime).inMilliseconds;
-    late List<double> result;
+    late List<double> resultAa; //, resultNoaa;
     try {
-      if (MlModel.usePlatformPlugin) {
-        result = await _runPlatformPluginPredict(inputList);
+      if (false) {
+        resultAa = await _runPlatformPluginPredict(rawRgbaBytes);
+        // resultNoaa = await _runPlatformPluginPredict(inputListNoaa);
+        // await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip");
+        // await encodeAndSaveData(
+        //   resultNoaa,
+        //   "star-noaa-mobile-embedding",
+        //   "clip",
+        // );
       } else {
-        result = _runFFIBasedPredict(inputList, sessionAddress);
+        resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress);
+        print('clip inference done with FFI package');
       }
     } catch (e, stackTrace) {
       _logger.severe(
@@ -59,15 +72,22 @@ class ClipImageEncoder extends MlModel {
     _logger.info(
       "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
     );
-    return result;
+    await encodeAndSaveData(
+      resultAa,
+      'singapore-rgba-mobile-embedding-ffi',
+      'clip',
+    );
+    return resultAa;
   }
   static List<double> _runFFIBasedPredict(
-    Float32List inputList,
+    Uint8List inputImageList,
     int sessionAddress,
   ) {
-    final inputOrt =
-        OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 256, 256]);
+    final inputOrt = OrtValueTensor.createTensorWithDataList(
+      inputImageList,
+      [1200, 1920, 4],
+    );
     final inputs = {'input': inputOrt};
     final session = OrtSession.fromAddress(sessionAddress);
     final runOptions = OrtRunOptions();
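
The FFI branch above now binds the raw RGBA buffer directly to an ONNX tensor, with the shape hardcoded to a single test image. A minimal sketch of a generalised version of that call, using only the onnxruntime Dart calls already visible in this file; the function name, the explicit width/height parameters, and the embedding extraction after session.run are illustrative assumptions, not code from this commit:

import "dart:typed_data" show Uint8List;

import "package:onnxruntime/onnxruntime.dart";

// Sketch only: generalises the hardcoded [1200, 1920, 4] test shape to any
// decoded image, assuming the buffer is laid out as height x width x RGBA.
List<double> runClipOnRgba(
  Uint8List rawRgbaBytes,
  int width,
  int height,
  int sessionAddress,
) {
  final inputOrt = OrtValueTensor.createTensorWithDataList(
    rawRgbaBytes,
    [height, width, 4],
  );
  final inputs = {'input': inputOrt};
  final session = OrtSession.fromAddress(sessionAddress);
  final runOptions = OrtRunOptions();
  final outputs = session.run(runOptions, inputs);
  inputOrt.release();
  runOptions.release();
  // Assumes the model emits a single [1, D] embedding tensor, which is how
  // the tail of _runFFIBasedPredict (not shown in this hunk) would read it.
  final embedding = (outputs[0]?.value as List<List<double>>)[0];
  for (final output in outputs) {
    output?.release();
  }
  return embedding;
}
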
@@ -83,10 +103,10 @@ class ClipImageEncoder extends MlModel {
   }
   static Future<List<double>> _runPlatformPluginPredict(
-    Float32List inputImageList,
+    Uint8List inputImageList,
   ) async {
     final OnnxDart plugin = OnnxDart();
-    final result = await plugin.predict(
+    final result = await plugin.predictRgba(
       inputImageList,
       _modelName,
     );
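
The diff does not show where rawRgbaBytes originates. A plausible caller-side sketch, assuming the bytes are read from the already decoded dart:ui Image via toByteData in rawRgba format; this helper is illustrative and not part of the commit:

import "dart:typed_data" show Uint8List;
import "dart:ui" as ui;

// Illustrative helper: dump a decoded image as the raw RGBA byte buffer
// that predict() / predictRgba() are now fed directly.
Future<Uint8List> rgbaBytesFromImage(ui.Image image) async {
  final byteData = await image.toByteData(format: ui.ImageByteFormat.rawRgba);
  if (byteData == null) {
    throw StateError("could not read RGBA bytes from image");
  }
  return byteData.buffer.asUint8List();
}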