From a33bbb22aee816ebe73f95d66fec9bbde49fc2c4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 6 Nov 2024 14:46:53 +0530 Subject: [PATCH] [mob][photos] Single rust method for faces and clip preprocessing --- .../face_detection_service.dart | 36 +++--- .../face_ml/face_recognition_service.dart | 22 ++-- .../clip/clip_image_encoder.dart | 20 ++-- .../semantic_search_service.dart | 13 +-- mobile/lib/src/rust/api/image_processing.dart | 4 + mobile/lib/src/rust/frb_generated.dart | 100 ++++++++++++++++- mobile/lib/src/rust/frb_generated.io.dart | 81 ++++++++++++++ mobile/lib/utils/ml_util.dart | 24 ++-- mobile/rust/src/api/image_processing.rs | 88 +++++++++++++++ mobile/rust/src/frb_generated.rs | 104 +++++++++++++++++- 10 files changed, 434 insertions(+), 58 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart index 37df4cb955..319deaa4f0 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart @@ -1,6 +1,5 @@ import "dart:async"; import 'dart:typed_data' show Float32List, Uint8List; -import 'dart:ui' as ui show Image; import 'package:logging/logging.dart'; import "package:onnx_dart/onnx_dart.dart"; @@ -9,7 +8,6 @@ import "package:photos/models/ml/face/dimension.dart"; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart"; import "package:photos/services/machine_learning/ml_model.dart"; -import "package:photos/src/rust/api/image_processing.dart"; import "package:photos/utils/image_ml_util.dart"; class YOLOFaceInterpreterRunException implements Exception {} @@ -44,10 +42,10 @@ class FaceDetectionService extends MlModel { /// Detects faces in the given image data. static Future> predict( - ui.Image image, - Uint8List rawRgbaBytes, + Uint8List resizedBytes, + int resizedHeight, + int resizedWidth, int sessionAddress, - String imagePath, ) async { assert( !MlModel.usePlatformPlugin @@ -58,32 +56,30 @@ class FaceDetectionService extends MlModel { final startTime = DateTime.now(); - final (result, timing, newWidth, newHeight) = - await processYoloFace(imagePath: imagePath); - _logger.info("Face detection preprocessing: \n $timing"); - final scaledSize = - Dimensions(width: newWidth.toInt(), height: newHeight.toInt()); - final preprocessingTime = DateTime.now(); - final preprocessingMs = - preprocessingTime.difference(startTime).inMilliseconds; + // final (result, timing, newWidth, newHeight) = + // await processYoloFace(imagePath: imagePath); + // _logger.info("Face detection preprocessing: \n $timing"); + // final scaledSize = + // Dimensions(width: newWidth.toInt(), height: newHeight.toInt()); + // final preprocessingTime = DateTime.now(); + // final preprocessingMs = + // preprocessingTime.difference(startTime).inMilliseconds; - final tempTime = DateTime.now(); + final preprocessingTime = DateTime.now(); + final scaledSize = Dimensions(width: resizedWidth, height: resizedHeight); final inputImageList = await resizedToPreprocessedYoloFace( - result, + resizedBytes, scaledSize.width, scaledSize.height, ); - _logger.info( - 'Face detection remaining dart processing: ${DateTime.now().difference(tempTime).inMilliseconds} ms', - ); // final (inputImageList, scaledSize) = await preprocessImageYoloFace( // image, // rawRgbaBytes, // ); // final preprocessingTime = DateTime.now(); - // final preprocessingMs = - // preprocessingTime.difference(startTime).inMilliseconds; + final preprocessingMs = + preprocessingTime.difference(startTime).inMilliseconds; // Run inference List>>? nestedResults = []; diff --git a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart index 4af71019b3..f788fbf62e 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart @@ -73,9 +73,11 @@ class FaceRecognitionService { int enteFileID, Image image, Uint8List rawRgbaBytes, + Uint8List resizedBytes, + int resizedHeight, + int resizedWidth, int faceDetectionAddress, int faceEmbeddingAddress, - String imagePath, ) async { final faceResults = []; final startTime = DateTime.now(); @@ -84,11 +86,11 @@ class FaceRecognitionService { final List faceDetectionResult = await _detectFacesSync( enteFileID, - image, - rawRgbaBytes, + resizedBytes, + resizedHeight, + resizedWidth, faceDetectionAddress, faceResults, - imagePath, ); final detectFacesTime = DateTime.now(); final detectFacesMs = detectFacesTime.difference(startTime).inMilliseconds; @@ -133,20 +135,20 @@ class FaceRecognitionService { /// Runs face recognition on the given image data. static Future> _detectFacesSync( int fileID, - Image image, - Uint8List rawRgbaBytes, + Uint8List resizedBytes, + int resizedHeight, + int resizedWidth, int interpreterAddress, List faceResults, - String imagePath, ) async { try { // Get the bounding boxes of the faces final List faces = await FaceDetectionService.predict( - image, - rawRgbaBytes, + resizedBytes, + resizedHeight, + resizedWidth, interpreterAddress, - imagePath, ); // Add detected faces to the faceResults diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart index 2fb506b0f8..1b21d057cb 100644 --- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart +++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart @@ -1,11 +1,9 @@ import "dart:typed_data" show Uint8List, Float32List; -import "dart:ui" show Image; import "package:logging/logging.dart"; import "package:onnx_dart/onnx_dart.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/machine_learning/ml_model.dart"; -import "package:photos/src/rust/api/image_processing.dart"; import "package:photos/utils/image_ml_util.dart"; import "package:photos/utils/ml_util.dart"; @@ -29,17 +27,17 @@ class ClipImageEncoder extends MlModel { factory ClipImageEncoder() => instance; static Future> predict( - Image image, - Uint8List rawRgbaBytes, - int sessionAddress, - String imagePath, [ + Uint8List resizedBytes, + int resizedHeight, + int resizedWidth, + int sessionAddress, [ int? enteFileID, ]) async { final startTime = DateTime.now(); // final inputList = await preprocessImageClip(image, rawRgbaBytes); - final (resizedBytes, timing, rgbWidth, rgbHeight) = - await processClip(imagePath: imagePath); - _logger.info("Clip preprocessing: \n $timing"); + // final (resizedBytes, timing, rgbWidth, rgbHeight) = + // await processClip(imagePath: imagePath); + // _logger.info("Clip preprocessing: \n $timing"); final preprocessingTime = DateTime.now(); final preprocessingMs = preprocessingTime.difference(startTime).inMilliseconds; @@ -47,8 +45,8 @@ class ClipImageEncoder extends MlModel { final tempTime = DateTime.now(); final inputList = await resizedToPreprocessedClip( resizedBytes, - rgbWidth.toInt(), - rgbHeight.toInt(), + resizedWidth, + resizedHeight, ); _logger.info( 'Clip remaining dart processing: ${DateTime.now().difference(tempTime).inMilliseconds} ms', diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index ead64a97bf..e6b5f60fea 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -1,7 +1,6 @@ import "dart:async" show unawaited; import "dart:developer" as dev show log; import "dart:math" show min; -import "dart:ui" show Image; import "package:computer/computer.dart"; import "package:flutter/foundation.dart"; @@ -267,16 +266,16 @@ class SemanticSearchService { static Future runClipImage( int enteFileID, - Image image, - Uint8List rawRgbaBytes, + Uint8List resizedBytes, + int resizedHeight, + int resizedWidth, int clipImageAddress, - String imagePath, ) async { final embedding = await ClipImageEncoder.predict( - image, - rawRgbaBytes, + resizedBytes, + resizedHeight, + resizedWidth, clipImageAddress, - imagePath, enteFileID, ); diff --git a/mobile/lib/src/rust/api/image_processing.dart b/mobile/lib/src/rust/api/image_processing.dart index d812d4d96c..69a6d84fc3 100644 --- a/mobile/lib/src/rust/api/image_processing.dart +++ b/mobile/lib/src/rust/api/image_processing.dart @@ -8,6 +8,10 @@ import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart'; import 'package:photos/src/rust/frb_generated.dart'; +Future<(Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt)> + processImageMlFromPath({required String imagePath}) => RustLib.instance.api + .crateApiImageProcessingProcessImageMlFromPath(imagePath: imagePath); + Future<(Uint8List, String, BigInt, BigInt)> processYoloFace( {required String imagePath}) => RustLib.instance.api diff --git a/mobile/lib/src/rust/frb_generated.dart b/mobile/lib/src/rust/frb_generated.dart index 2dd1ad6bf8..50042b923c 100644 --- a/mobile/lib/src/rust/frb_generated.dart +++ b/mobile/lib/src/rust/frb_generated.dart @@ -72,7 +72,7 @@ class RustLib extends BaseEntrypoint { String get codegenVersion => '2.5.1'; @override - int get rustContentHash => -2017876324; + int get rustContentHash => -1741400115; static const kDefaultExternalLibraryLoaderConfig = ExternalLibraryLoaderConfig( @@ -86,6 +86,10 @@ abstract class RustLibApi extends BaseApi { Future<(Uint8List, String, BigInt, BigInt)> crateApiImageProcessingProcessClip({required String imagePath}); + Future<(Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt)> + crateApiImageProcessingProcessImageMlFromPath( + {required String imagePath}); + Future<(Uint8List, String, BigInt, BigInt)> crateApiImageProcessingProcessYoloFace({required String imagePath}); @@ -128,6 +132,34 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { argNames: ["imagePath"], ); + @override + Future<(Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt)> + crateApiImageProcessingProcessImageMlFromPath( + {required String imagePath}) { + return handler.executeNormal(NormalTask( + callFfi: (port_) { + final arg0 = cst_encode_String(imagePath); + return wire + .wire__crate__api__image_processing__process_image_ml_from_path( + port_, arg0); + }, + codec: DcoCodec( + decodeSuccessData: + dco_decode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize, + decodeErrorData: null, + ), + constMeta: kCrateApiImageProcessingProcessImageMlFromPathConstMeta, + argValues: [imagePath], + apiImpl: this, + )); + } + + TaskConstMeta get kCrateApiImageProcessingProcessImageMlFromPathConstMeta => + const TaskConstMeta( + debugName: "process_image_ml_from_path", + argNames: ["imagePath"], + ); + @override Future<(Uint8List, String, BigInt, BigInt)> crateApiImageProcessingProcessYoloFace({required String imagePath}) { @@ -225,6 +257,31 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { ); } + @protected + ( + Uint8List, + BigInt, + BigInt, + Uint8List, + BigInt, + BigInt + ) dco_decode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + dynamic raw) { + // Codec=Dco (DartCObject based), see doc to use other codecs + final arr = raw as List; + if (arr.length != 6) { + throw Exception('Expected 6 elements, got ${arr.length}'); + } + return ( + dco_decode_list_prim_u_8_strict(arr[0]), + dco_decode_usize(arr[1]), + dco_decode_usize(arr[2]), + dco_decode_list_prim_u_8_strict(arr[3]), + dco_decode_usize(arr[4]), + dco_decode_usize(arr[5]), + ); + } + @protected int dco_decode_u_8(dynamic raw) { // Codec=Dco (DartCObject based), see doc to use other codecs @@ -269,6 +326,33 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { return (var_field0, var_field1, var_field2, var_field3); } + @protected + ( + Uint8List, + BigInt, + BigInt, + Uint8List, + BigInt, + BigInt + ) sse_decode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + SseDeserializer deserializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + final var_field0 = sse_decode_list_prim_u_8_strict(deserializer); + final var_field1 = sse_decode_usize(deserializer); + final var_field2 = sse_decode_usize(deserializer); + final var_field3 = sse_decode_list_prim_u_8_strict(deserializer); + final var_field4 = sse_decode_usize(deserializer); + final var_field5 = sse_decode_usize(deserializer); + return ( + var_field0, + var_field1, + var_field2, + var_field3, + var_field4, + var_field5 + ); + } + @protected int sse_decode_u_8(SseDeserializer deserializer) { // Codec=Sse (Serialization based), see doc to use other codecs @@ -334,6 +418,20 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_usize(self.$4, serializer); } + @protected + void + sse_encode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + (Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt) self, + SseSerializer serializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + sse_encode_list_prim_u_8_strict(self.$1, serializer); + sse_encode_usize(self.$2, serializer); + sse_encode_usize(self.$3, serializer); + sse_encode_list_prim_u_8_strict(self.$4, serializer); + sse_encode_usize(self.$5, serializer); + sse_encode_usize(self.$6, serializer); + } + @protected void sse_encode_u_8(int self, SseSerializer serializer) { // Codec=Sse (Serialization based), see doc to use other codecs diff --git a/mobile/lib/src/rust/frb_generated.io.dart b/mobile/lib/src/rust/frb_generated.io.dart index 273664010d..b06988bf2e 100644 --- a/mobile/lib/src/rust/frb_generated.io.dart +++ b/mobile/lib/src/rust/frb_generated.io.dart @@ -31,6 +31,17 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { (Uint8List, String, BigInt, BigInt) dco_decode_record_list_prim_u_8_strict_string_usize_usize(dynamic raw); + @protected + ( + Uint8List, + BigInt, + BigInt, + Uint8List, + BigInt, + BigInt + ) dco_decode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + dynamic raw); + @protected int dco_decode_u_8(dynamic raw); @@ -51,6 +62,17 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { sse_decode_record_list_prim_u_8_strict_string_usize_usize( SseDeserializer deserializer); + @protected + ( + Uint8List, + BigInt, + BigInt, + Uint8List, + BigInt, + BigInt + ) sse_decode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + SseDeserializer deserializer); + @protected int sse_decode_u_8(SseDeserializer deserializer); @@ -97,6 +119,19 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { wireObj.field3 = cst_encode_usize(apiObj.$4); } + @protected + void cst_api_fill_to_wire_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + (Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt) apiObj, + wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize + wireObj) { + wireObj.field0 = cst_encode_list_prim_u_8_strict(apiObj.$1); + wireObj.field1 = cst_encode_usize(apiObj.$2); + wireObj.field2 = cst_encode_usize(apiObj.$3); + wireObj.field3 = cst_encode_list_prim_u_8_strict(apiObj.$4); + wireObj.field4 = cst_encode_usize(apiObj.$5); + wireObj.field5 = cst_encode_usize(apiObj.$6); + } + @protected int cst_encode_u_8(int raw); @@ -114,6 +149,12 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { void sse_encode_record_list_prim_u_8_strict_string_usize_usize( (Uint8List, String, BigInt, BigInt) self, SseSerializer serializer); + @protected + void + sse_encode_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize( + (Uint8List, BigInt, BigInt, Uint8List, BigInt, BigInt) self, + SseSerializer serializer); + @protected void sse_encode_u_8(int self, SseSerializer serializer); @@ -190,6 +231,27 @@ class RustLibWire implements BaseWire { _wire__crate__api__image_processing__process_clipPtr.asFunction< void Function(int, ffi.Pointer)>(); + void wire__crate__api__image_processing__process_image_ml_from_path( + int port_, + ffi.Pointer image_path, + ) { + return _wire__crate__api__image_processing__process_image_ml_from_path( + port_, + image_path, + ); + } + + late final _wire__crate__api__image_processing__process_image_ml_from_pathPtr = + _lookup< + ffi.NativeFunction< + ffi.Void Function( + ffi.Int64, ffi.Pointer)>>( + 'frbgen_photos_wire__crate__api__image_processing__process_image_ml_from_path'); + late final _wire__crate__api__image_processing__process_image_ml_from_path = + _wire__crate__api__image_processing__process_image_ml_from_pathPtr + .asFunction< + void Function(int, ffi.Pointer)>(); + void wire__crate__api__image_processing__process_yolo_face( int port_, ffi.Pointer image_path, @@ -295,3 +357,22 @@ final class wire_cst_record_list_prim_u_8_strict_string_usize_usize @ffi.UintPtr() external int field3; } + +final class wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize + extends ffi.Struct { + external ffi.Pointer field0; + + @ffi.UintPtr() + external int field1; + + @ffi.UintPtr() + external int field2; + + external ffi.Pointer field3; + + @ffi.UintPtr() + external int field4; + + @ffi.UintPtr() + external int field5; +} diff --git a/mobile/lib/utils/ml_util.dart b/mobile/lib/utils/ml_util.dart index d511d22de3..783e049050 100644 --- a/mobile/lib/utils/ml_util.dart +++ b/mobile/lib/utils/ml_util.dart @@ -23,6 +23,7 @@ import "package:photos/services/machine_learning/ml_exceptions.dart"; import "package:photos/services/machine_learning/ml_result.dart"; import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; import "package:photos/services/search_service.dart"; +import "package:photos/src/rust/api/image_processing.dart"; import "package:photos/src/rust/custom/init_frb.dart"; import "package:photos/utils/file_util.dart"; import "package:photos/utils/image_ml_util.dart"; @@ -395,18 +396,23 @@ Future analyzeImageStatic(Map args) async { "Start analyzeImageStatic for fileID $enteFileID (runFaces: $runFaces, runClip: $runClip)", ); await initFrb(); - final safePath = await safePathFromImagepath(imagePath); - final startTime = DateTime.now(); // Decode the image once to use for both face detection and alignment + final safePath = await safePathFromImagepath(imagePath); final (image, rawRgbaBytes) = await decodeImageFromPath(imagePath); + final decodeTime = DateTime.now(); + final decodeMs = decodeTime.difference(startTime).inMilliseconds; + + final (faceBytes, faceHeight, faceWidth, clipBytes, clipHeight, clipWidth) = + await processImageMlFromPath(imagePath: safePath); + _logger.info( + 'ML processing in rust took ${DateTime.now().difference(decodeTime).inMilliseconds} ms', + ); final decodedImageSize = Dimensions(height: image.height, width: image.width); final result = MLResult.fromEnteFileID(enteFileID); result.decodedImageSize = decodedImageSize; - final decodeTime = DateTime.now(); - final decodeMs = decodeTime.difference(startTime).inMilliseconds; String faceMsString = "", clipMsString = ""; final pipelines = await Future.wait([ @@ -415,9 +421,11 @@ Future analyzeImageStatic(Map args) async { enteFileID, image, rawRgbaBytes, + faceBytes, + faceHeight.toInt(), + faceWidth.toInt(), faceDetectionAddress, faceEmbeddingAddress, - safePath, ).then((result) { faceMsString = ", faces: ${DateTime.now().difference(decodeTime).inMilliseconds} ms"; @@ -427,10 +435,10 @@ Future analyzeImageStatic(Map args) async { runClip ? SemanticSearchService.runClipImage( enteFileID, - image, - rawRgbaBytes, + clipBytes, + clipHeight.toInt(), + clipWidth.toInt(), clipImageAddress, - safePath, ).then((result) { clipMsString = ", clip: ${DateTime.now().difference(decodeTime).inMilliseconds} ms"; diff --git a/mobile/rust/src/api/image_processing.rs b/mobile/rust/src/api/image_processing.rs index ed31944d2b..70fc3d83d4 100644 --- a/mobile/rust/src/api/image_processing.rs +++ b/mobile/rust/src/api/image_processing.rs @@ -1,6 +1,94 @@ use resize::{px::RGB, Pixel::RGB8, Type::Lanczos3, Type::Mitchell}; use rgb::FromSlice; +pub fn process_image_ml_from_path( + image_path: &str, +) -> (Vec, usize, usize, Vec, usize, usize) { + // Load the image (~200ms) + let img = image::open(image_path).expect("Failed to open image"); + + // Get dimensions for resized images (0ms) + let (width, height) = (img.width() as usize, img.height() as usize); + let scale_face = f32::min(640.0 / width as f32, 640.0 / height as f32); + let scale_clip = f32::max(256.0 / width as f32, 256.0 / height as f32); + let (new_width_face, new_height_face) = ( + f32::round(width as f32 * scale_face) as usize, + f32::round(height as f32 * scale_face) as usize, + ); + let (new_width_clip, new_height_clip) = ( + f32::round(width as f32 * scale_clip) as usize, + f32::round(height as f32 * scale_clip) as usize, + ); + let mut interpolation_face = Lanczos3; + if scale_face > 1.0 { + interpolation_face = Mitchell; + } + let mut interpolation_clip = Lanczos3; + if scale_clip > 1.0 { + interpolation_clip = Mitchell; + } + + // Convert image to RGB8 (~150ms) + let rgb_img = img.into_rgb8(); + + // Convert RGB8 to Vec (~30ms) + let rgb_vec = rgb_img.to_vec(); + + // Create resizer (~20ms) + let mut resizer_face = resize::new( + width, + height, + new_width_face, + new_height_face, + RGB8, + interpolation_face, + ) + .unwrap(); + let mut resizer_clip = resize::new( + width, + height, + new_width_clip, + new_height_clip, + RGB8, + interpolation_clip, + ) + .unwrap(); + + // Create buffer for resized image (~120ms) + let mut dst_face = vec![RGB::new(0, 0, 0); new_width_face * new_height_face]; + let mut dst_clip = vec![RGB::new(0, 0, 0); new_width_clip * new_height_clip]; + + // Resize the image (~120ms) + resizer_face + .resize(rgb_vec.as_rgb(), &mut dst_face) + .unwrap(); + resizer_clip + .resize(rgb_vec.as_rgb(), &mut dst_clip) + .unwrap(); + + // Return resized images as Vec (~120ms) + let mut result_face = Vec::with_capacity(new_width_face * new_height_face * 3); + for pixel in dst_face { + result_face.push(pixel.r); + result_face.push(pixel.g); + result_face.push(pixel.b); + } + let mut result_clip = Vec::with_capacity(new_width_clip * new_height_clip * 3); + for pixel in dst_clip { + result_clip.push(pixel.r); + result_clip.push(pixel.g); + result_clip.push(pixel.b); + } + ( + result_face, + new_height_face, + new_width_face, + result_clip, + new_height_clip, + new_width_clip, + ) +} + pub fn process_yolo_face(image_path: &str) -> (Vec, String, usize, usize) { let mut timing = String::new(); timing.push_str("Yolo Face\n"); diff --git a/mobile/rust/src/frb_generated.rs b/mobile/rust/src/frb_generated.rs index b429be7b22..3f335840ad 100644 --- a/mobile/rust/src/frb_generated.rs +++ b/mobile/rust/src/frb_generated.rs @@ -37,7 +37,7 @@ flutter_rust_bridge::frb_generated_boilerplate!( default_rust_auto_opaque = RustAutoOpaqueNom, ); pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.5.1"; -pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -2017876324; +pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -1741400115; // Section: executor @@ -68,6 +68,29 @@ fn wire__crate__api__image_processing__process_clip_impl( }, ) } +fn wire__crate__api__image_processing__process_image_ml_from_path_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + image_path: impl CstDecode, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "process_image_ml_from_path", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let api_image_path = image_path.cst_decode(); + move |context| { + transform_result_dco::<_, _, ()>((move || { + let output_ok = Result::<_, ()>::Ok( + crate::api::image_processing::process_image_ml_from_path(&api_image_path), + )?; + Ok(output_ok) + })()) + } + }, + ) +} fn wire__crate__api__image_processing__process_yolo_face_impl( port_: flutter_rust_bridge::for_generated::MessagePort, image_path: impl CstDecode, @@ -174,6 +197,21 @@ impl SseDecode for (Vec, String, usize, usize) { } } +impl SseDecode for (Vec, usize, usize, Vec, usize, usize) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { + let mut var_field0 = >::sse_decode(deserializer); + let mut var_field1 = ::sse_decode(deserializer); + let mut var_field2 = ::sse_decode(deserializer); + let mut var_field3 = >::sse_decode(deserializer); + let mut var_field4 = ::sse_decode(deserializer); + let mut var_field5 = ::sse_decode(deserializer); + return ( + var_field0, var_field1, var_field2, var_field3, var_field4, var_field5, + ); + } +} + impl SseDecode for u8 { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { @@ -261,6 +299,18 @@ impl SseEncode for (Vec, String, usize, usize) { } } +impl SseEncode for (Vec, usize, usize, Vec, usize, usize) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) { + >::sse_encode(self.0, serializer); + ::sse_encode(self.1, serializer); + ::sse_encode(self.2, serializer); + >::sse_encode(self.3, serializer); + ::sse_encode(self.4, serializer); + ::sse_encode(self.5, serializer); + } +} + impl SseEncode for u8 { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) { @@ -346,6 +396,21 @@ mod io { ) } } + impl CstDecode<(Vec, usize, usize, Vec, usize, usize)> + for wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize + { + // Codec=Cst (C-struct based), see doc to use other codecs + fn cst_decode(self) -> (Vec, usize, usize, Vec, usize, usize) { + ( + self.field0.cst_decode(), + self.field1.cst_decode(), + self.field2.cst_decode(), + self.field3.cst_decode(), + self.field4.cst_decode(), + self.field5.cst_decode(), + ) + } + } impl NewWithNullPtr for wire_cst_record_list_prim_u_8_strict_string_usize_usize { fn new_with_null_ptr() -> Self { Self { @@ -361,6 +426,25 @@ mod io { Self::new_with_null_ptr() } } + impl NewWithNullPtr + for wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize + { + fn new_with_null_ptr() -> Self { + Self { + field0: core::ptr::null_mut(), + field1: Default::default(), + field2: Default::default(), + field3: core::ptr::null_mut(), + field4: Default::default(), + field5: Default::default(), + } + } + } + impl Default for wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize { + fn default() -> Self { + Self::new_with_null_ptr() + } + } #[no_mangle] pub extern "C" fn frbgen_photos_wire__crate__api__image_processing__process_clip( @@ -370,6 +454,14 @@ mod io { wire__crate__api__image_processing__process_clip_impl(port_, image_path) } + #[no_mangle] + pub extern "C" fn frbgen_photos_wire__crate__api__image_processing__process_image_ml_from_path( + port_: i64, + image_path: *mut wire_cst_list_prim_u_8_strict, + ) { + wire__crate__api__image_processing__process_image_ml_from_path_impl(port_, image_path) + } + #[no_mangle] pub extern "C" fn frbgen_photos_wire__crate__api__image_processing__process_yolo_face( port_: i64, @@ -415,6 +507,16 @@ mod io { field2: usize, field3: usize, } + #[repr(C)] + #[derive(Clone, Copy)] + pub struct wire_cst_record_list_prim_u_8_strict_usize_usize_list_prim_u_8_strict_usize_usize { + field0: *mut wire_cst_list_prim_u_8_strict, + field1: usize, + field2: usize, + field3: *mut wire_cst_list_prim_u_8_strict, + field4: usize, + field5: usize, + } } #[cfg(not(target_family = "wasm"))] pub use io::*;