From bd232c151a2a5357f0aef29f9b7afa3b527a181a Mon Sep 17 00:00:00 2001
From: laurenspriem
Date: Wed, 25 Sep 2024 14:34:25 +0530
Subject: [PATCH] [mob][photos] Test for iOS

---
 .../clip/clip_image_encoder.dart              | 48 +++++++++++++------
 1 file changed, 34 insertions(+), 14 deletions(-)

diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
index f51cf25516..0f76a25ff6 100644
--- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
@@ -1,15 +1,17 @@
-import "dart:typed_data" show Uint8List, Float32List;
+import "dart:typed_data" show Float32List, Int8List, Uint8List;
 import "dart:ui" show Image;
 import "package:logging/logging.dart";
 import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
-import "package:photos/utils/image_ml_util.dart";
+import "package:photos/utils/debug_ml_export_data.dart";
+// import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";
 
 class ClipImageEncoder extends MlModel {
-  static const kRemoteBucketModelPath = "mobileclip_s2_image.onnx";
+  static const kRemoteBucketModelPath =
+      "mobileclip_s2_image_opset18_rgba_sim.onnx";
 
   static const _modelName = "ClipImageEncoder";
 
   @override
@@ -34,16 +36,27 @@ class ClipImageEncoder extends MlModel {
     int? enteFileID,
   ]) async {
     final startTime = DateTime.now();
-    final inputList = await preprocessImageClip(image, rawRgbaBytes);
+    // final inputListAa = await preprocessImageClip(image, rawRgbaBytes, true);
+    // final inputListNoaa = await preprocessImageClip(image, rawRgbaBytes, false);
+    // await encodeAndSaveData(inputListAa, "star-aa-mobile-input", "clip");
+    // await encodeAndSaveData(inputListNoaa, "star-noaa-mobile-input", "clip");
     final preprocessingTime = DateTime.now();
     final preprocessingMs =
         preprocessingTime.difference(startTime).inMilliseconds;
-    late List<double> result;
+    late List<double> resultAa; //, resultNoaa;
     try {
-      if (MlModel.usePlatformPlugin) {
-        result = await _runPlatformPluginPredict(inputList);
+      if (false) {
+        resultAa = await _runPlatformPluginPredict(rawRgbaBytes);
+        // resultNoaa = await _runPlatformPluginPredict(inputListNoaa);
+        // await encodeAndSaveData(resultAa, "star-aa-mobile-embedding", "clip");
+        // await encodeAndSaveData(
+        //   resultNoaa,
+        //   "star-noaa-mobile-embedding",
+        //   "clip",
+        // );
       } else {
-        result = _runFFIBasedPredict(inputList, sessionAddress);
+        resultAa = _runFFIBasedPredict(rawRgbaBytes, sessionAddress);
+        print('clip inference done with FFI package');
       }
     } catch (e, stackTrace) {
       _logger.severe(
@@ -59,15 +72,22 @@ class ClipImageEncoder extends MlModel {
     _logger.info(
       "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
     );
-    return result;
+    await encodeAndSaveData(
+      resultAa,
+      'singapore-rgba-mobile-embedding-ffi',
+      'clip',
+    );
+    return resultAa;
   }
 
   static List<double> _runFFIBasedPredict(
-    Float32List inputList,
+    Uint8List inputImageList,
     int sessionAddress,
   ) {
-    final inputOrt =
-        OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 256, 256]);
+    final inputOrt = OrtValueTensor.createTensorWithDataList(
+      inputImageList,
+      [1200, 1920, 4],
+    );
     final inputs = {'input': inputOrt};
     final session = OrtSession.fromAddress(sessionAddress);
     final runOptions = OrtRunOptions();
@@ -83,10 +103,10 @@ class ClipImageEncoder extends MlModel {
   }
 
   static Future<List<double>> _runPlatformPluginPredict(
-    Float32List inputImageList,
+    Uint8List inputImageList,
   ) async {
     final OnnxDart plugin = OnnxDart();
-    final result = await plugin.predict(
+    final result = await plugin.predictRgba(
       inputImageList,
       _modelName,
     );