diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart
index 81d1df75f4..a1258edbfb 100644
--- a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart
@@ -4,6 +4,7 @@ import "package:onnxruntime/onnxruntime.dart";
 import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
 import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart';
 import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart';
+import "package:photos/utils/image_isolate.dart";
 
 class ONNX extends MLFramework {
   static const kModelBucketEndpoint = "https://models.ente.io/";
@@ -72,11 +73,11 @@ class ONNX extends MLFramework {
   Future<List<double>> getImageEmbedding(String imagePath) async {
     try {
       final startTime = DateTime.now();
-      final result = await _clipImage.inferByImage({
-        // TODO: add computer back later
-        "imagePath": imagePath,
-        "address": _imageEncoderAddress,
-      });
+      // TODO: properly integrate with other ml later (FaceMlService)
+      final result = await ImageIsolate.instance.inferClipImageEmbedding(
+        imagePath,
+        _imageEncoderAddress,
+      );
       final endTime = DateTime.now();
       _logger.info(
         "createImageEmbedding took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart
index 8988913e3b..f4fe9c4c0f 100644
--- a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart
@@ -26,7 +26,7 @@ class OnnxImageEncoder {
     return -1;
   }
 
-  Future<List<double>> inferByImage(Map args) async {
+  static Future<List<double>> inferByImage(Map args) async {
     final imageData = await File(args["imagePath"]).readAsBytes();
     final image = await decodeImageFromData(imageData);
     final ByteData imgByteData = await getByteDataFromImage(image);
diff --git a/mobile/lib/utils/image_isolate.dart b/mobile/lib/utils/image_isolate.dart
index 9ebad2bb94..ce67c9fa9f 100644
--- a/mobile/lib/utils/image_isolate.dart
+++ b/mobile/lib/utils/image_isolate.dart
@@ -6,11 +6,13 @@ import 'dart:typed_data' show Uint8List;
 import "package:dart_ui_isolate/dart_ui_isolate.dart";
 import "package:logging/logging.dart";
 import "package:photos/face/model/box.dart";
+import "package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:synchronized/synchronized.dart";
 
 enum ImageOperation {
   generateFaceThumbnails,
+  clip,
 }
 
 class ImageIsolate {
@@ -88,6 +90,15 @@
             faceBoxes,
           );
           sendPort.send(List.from(results));
+        case ImageOperation.clip:
+          final imagePath = args['imagePath'] as String;
+          final address = args['address'] as int;
+          final result = await OnnxImageEncoder.inferByImage({
+            'imagePath': imagePath,
+            'address': address,
+          });
+          sendPort.send(List.from(result));
+          break;
       }
     } catch (e, stackTrace) {
       sendPort
@@ -175,4 +186,16 @@
       ),
     ).then((value) => value.cast<Uint8List>());
   }
+
+  Future<List<double>> inferClipImageEmbedding(String imagePath, int encoderAddress) async {
+    return await _runInIsolate(
+      (
+        ImageOperation.clip,
+        {
+          'imagePath': imagePath,
+          'address': encoderAddress,
+        },
+      ),
+    ).then((value) => value.cast<double>());
+  }
 }
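
Reviewer note: a minimal sketch of how the new public entry point is meant to be called. The `demoClipEmbedding` wrapper and the way `encoderAddress` is obtained are illustrative only; in the diff itself, `ONNX.getImageEmbedding` supplies `_imageEncoderAddress` after loading the CLIP image encoder model.

```dart
import "package:photos/utils/image_isolate.dart";

// Hypothetical caller; not part of this diff. `encoderAddress` is assumed to
// be the native ORT session address produced when the CLIP image encoder was
// loaded (the same value getImageEmbedding passes as _imageEncoderAddress).
Future<void> demoClipEmbedding(String imagePath, int encoderAddress) async {
  // Runs decode + preprocessing + ONNX inference off the main isolate,
  // via the ImageOperation.clip case in ImageIsolate.
  final List<double> embedding =
      await ImageIsolate.instance.inferClipImageEmbedding(
    imagePath,
    encoderAddress,
  );
  print("CLIP image embedding has ${embedding.length} dimensions");
}
```

Note that `OnnxImageEncoder.inferByImage` is made `static` so it can be invoked inside the spawned isolate without carrying encoder instance state across the isolate boundary; only the image path and the session address are sent as messages.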