[mob][photos] Test for iOS
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
import "dart:typed_data" show Uint8List, Float32List;
|
||||
import "dart:typed_data" show Float32List, Int8List, Uint8List;
|
||||
import "dart:ui" show Image;
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:photos/utils/debug_ml_export_data.dart";
|
||||
// import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
class ClipImageEncoder extends MlModel {
|
||||
static const kRemoteBucketModelPath = "mobileclip_s2_image.onnx";
|
||||
static const kRemoteBucketModelPath =
|
||||
"mobileclip_s2_image_opset18_rgba_sim.onnx";
|
||||
static const _modelName = "ClipImageEncoder";
|
||||
|
||||
@override
|
||||
@@ -34,16 +36,27 @@ class ClipImageEncoder extends MlModel {
|
||||
int? enteFileID,
|
||||
]) async {
|
||||
final startTime = DateTime.now();
|
||||
final inputList = await preprocessImageClip(image, rawRgbaBytes);
|
||||
// final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true);
|
||||
// final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false);
|
||||
// await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip");
|
||||
// await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip");
|
||||
final preprocessingTime = DateTime.now();
|
||||
final preprocessingMs =
|
||||
preprocessingTime.difference(startTime).inMilliseconds;
|
||||
late List<double> result;
|
||||
late List<double> resultAa; //, resultNoaa;
|
||||
try {
|
||||
if (MlModel.usePlatformPlugin) {
|
||||
result = await _runPlatformPluginPredict(inputList);
|
||||
if (false) {
|
||||
resultAa = await _runPlatformPluginPredict(rawRgbaBytes);
|
||||
// resultNoaa = await _runPlatformPluginPredict(inputListNoaa);
|
||||
// await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip");
|
||||
// await encodeAndSaveData(
|
||||
// resultNoaa,
|
||||
// "star-noaa-mobile-embedding",
|
||||
// "clip",
|
||||
// );
|
||||
} else {
|
||||
result = _runFFIBasedPredict(inputList, sessionAddress);
|
||||
resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress);
|
||||
print('clip inference done with FFI package');
|
||||
}
|
||||
} catch (e, stackTrace) {
|
||||
_logger.severe(
|
||||
@@ -59,15 +72,22 @@ class ClipImageEncoder extends MlModel {
|
||||
_logger.info(
|
||||
"Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
|
||||
);
|
||||
return result;
|
||||
await encodeAndSaveData(
|
||||
resultAa,
|
||||
'singapore-rgba-mobile-embedding-ffi',
|
||||
'clip',
|
||||
);
|
||||
return resultAa;
|
||||
}
|
||||
|
||||
static List<double> _runFFIBasedPredict(
|
||||
Float32List inputList,
|
||||
Uint8List inputImageList,
|
||||
int sessionAddress,
|
||||
) {
|
||||
final inputOrt =
|
||||
OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 256, 256]);
|
||||
final inputOrt = OrtValueTensor.createTensorWithDataList(
|
||||
inputImageList,
|
||||
[1200, 1920, 4],
|
||||
);
|
||||
final inputs = {'input': inputOrt};
|
||||
final session = OrtSession.fromAddress(sessionAddress);
|
||||
final runOptions = OrtRunOptions();
|
||||
@@ -83,10 +103,10 @@ class ClipImageEncoder extends MlModel {
|
||||
}
|
||||
|
||||
static Future<List<double>> _runPlatformPluginPredict(
|
||||
Float32List inputImageList,
|
||||
Uint8List inputImageList,
|
||||
) async {
|
||||
final OnnxDart plugin = OnnxDart();
|
||||
final result = await plugin.predict(
|
||||
final result = await plugin.predictRgba(
|
||||
inputImageList,
|
||||
_modelName,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user