From 856a87f01c0998b6334665e0a99b3e7c277fcdda Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 12 Aug 2025 11:36:38 +0530 Subject: [PATCH] Rust api to do entire search with potential keys in rust --- .../photos/lib/src/rust/api/usearch_api.dart | 5 + .../photos/lib/src/rust/frb_generated.dart | 117 ++++++++++++++++-- .../photos/lib/src/rust/frb_generated.io.dart | 22 ++++ .../apps/photos/rust/src/api/usearch_api.rs | 32 +++++ mobile/apps/photos/rust/src/frb_generated.rs | 107 ++++++++++++++-- 5 files changed, 259 insertions(+), 24 deletions(-) diff --git a/mobile/apps/photos/lib/src/rust/api/usearch_api.dart b/mobile/apps/photos/lib/src/rust/api/usearch_api.dart index 7d370723cd..118cb1b38c 100644 --- a/mobile/apps/photos/lib/src/rust/api/usearch_api.dart +++ b/mobile/apps/photos/lib/src/rust/api/usearch_api.dart @@ -21,6 +21,11 @@ abstract class VectorDb implements RustOpaqueInterface { Future bulkRemoveVectors({required Uint64List keys}); + Future<(Uint64List, List, List)> bulkSearchKeys( + {required Uint64List potentialKeys, + required BigInt count, + required bool exact}); + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count, diff --git a/mobile/apps/photos/lib/src/rust/frb_generated.dart b/mobile/apps/photos/lib/src/rust/frb_generated.dart index c9496ecf9e..931d5ffb3a 100644 --- a/mobile/apps/photos/lib/src/rust/frb_generated.dart +++ b/mobile/apps/photos/lib/src/rust/frb_generated.dart @@ -74,7 +74,7 @@ class RustLib extends BaseEntrypoint { String get codegenVersion => '2.11.1'; @override - int get rustContentHash => 1543559173; + int get rustContentHash => 1360671619; static const kDefaultExternalLibraryLoaderConfig = ExternalLibraryLoaderConfig( @@ -101,6 +101,13 @@ abstract class RustLibApi extends BaseApi { Future crateApiUsearchApiVectorDbBulkRemoveVectors( {required VectorDb that, required Uint64List keys}); + Future<(Uint64List, List, List)> + crateApiUsearchApiVectorDbBulkSearchKeys( + {required VectorDb that, + required Uint64List potentialKeys, + required BigInt count, + required bool exact}); + Future<(List, List)> crateApiUsearchApiVectorDbBulkSearchVectors( {required VectorDb that, @@ -272,6 +279,41 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { argNames: ["that", "keys"], ); + @override + Future<(Uint64List, List, List)> + crateApiUsearchApiVectorDbBulkSearchKeys( + {required VectorDb that, + required Uint64List potentialKeys, + required BigInt count, + required bool exact}) { + return handler.executeNormal(NormalTask( + callFfi: (port_) { + final serializer = SseSerializer(generalizedFrbRustBinding); + sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( + that, serializer); + sse_encode_list_prim_u_64_strict(potentialKeys, serializer); + sse_encode_usize(count, serializer); + sse_encode_bool(exact, serializer); + pdeCallFfi(generalizedFrbRustBinding, serializer, + funcId: 5, port: port_); + }, + codec: SseCodec( + decodeSuccessData: + sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict, + decodeErrorData: null, + ), + constMeta: kCrateApiUsearchApiVectorDbBulkSearchKeysConstMeta, + argValues: [that, potentialKeys, count, exact], + apiImpl: this, + )); + } + + TaskConstMeta get kCrateApiUsearchApiVectorDbBulkSearchKeysConstMeta => + const TaskConstMeta( + debugName: "VectorDb_bulk_search_keys", + argNames: ["that", "potentialKeys", "count", "exact"], + ); + @override Future<(List, List)> crateApiUsearchApiVectorDbBulkSearchVectors( @@ -288,7 +330,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_usize(count, serializer); sse_encode_bool(exact, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 5, port: port_); + funcId: 6, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -317,7 +359,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 6, port: port_); + funcId: 7, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_bool, @@ -343,7 +385,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 7, port: port_); + funcId: 8, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -370,7 +412,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 8, port: port_); + funcId: 9, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -399,7 +441,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 9, port: port_); + funcId: 10, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_list_prim_f_32_strict, @@ -425,7 +467,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(filePath, serializer); sse_encode_usize(dimensions, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 10)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 11)!; }, codec: SseCodec( decodeSuccessData: @@ -454,7 +496,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 11, port: port_); + funcId: 12, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_usize, @@ -480,7 +522,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 12, port: port_); + funcId: 13, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -513,7 +555,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_usize(count, serializer); sse_encode_bool(exact, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 13, port: port_); + funcId: 14, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -538,7 +580,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(name, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 14)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 15)!; }, codec: SseCodec( decodeSuccessData: sse_decode_String, @@ -561,7 +603,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 15, port: port_); + funcId: 16, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -683,6 +725,25 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { ); } + @protected + ( + Uint64List, + List, + List + ) dco_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + dynamic raw) { + // Codec=Dco (DartCObject based), see doc to use other codecs + final arr = raw as List; + if (arr.length != 3) { + throw Exception('Expected 3 elements, got ${arr.length}'); + } + return ( + dco_decode_list_prim_u_64_strict(arr[0]), + dco_decode_list_list_prim_u_64_strict(arr[1]), + dco_decode_list_list_prim_f_32_strict(arr[2]), + ); + } + @protected ( Uint64List, @@ -852,6 +913,20 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { return (var_field0, var_field1); } + @protected + ( + Uint64List, + List, + List + ) sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + SseDeserializer deserializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + final var_field0 = sse_decode_list_prim_u_64_strict(deserializer); + final var_field1 = sse_decode_list_list_prim_u_64_strict(deserializer); + final var_field2 = sse_decode_list_list_prim_f_32_strict(deserializer); + return (var_field0, var_field1, var_field2); + } + @protected (Uint64List, Float32List) sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict( @@ -1020,6 +1095,17 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_list_prim_f_32_strict(self.$2, serializer); } + @protected + void + sse_encode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + (Uint64List, List, List) self, + SseSerializer serializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + sse_encode_list_prim_u_64_strict(self.$1, serializer); + sse_encode_list_list_prim_u_64_strict(self.$2, serializer); + sse_encode_list_list_prim_f_32_strict(self.$3, serializer); + } + @protected void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict( (Uint64List, Float32List) self, SseSerializer serializer) { @@ -1108,6 +1194,13 @@ class VectorDbImpl extends RustOpaque implements VectorDb { RustLib.instance.api .crateApiUsearchApiVectorDbBulkRemoveVectors(that: this, keys: keys); + Future<(Uint64List, List, List)> bulkSearchKeys( + {required Uint64List potentialKeys, + required BigInt count, + required bool exact}) => + RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchKeys( + that: this, potentialKeys: potentialKeys, count: count, exact: exact); + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count, diff --git a/mobile/apps/photos/lib/src/rust/frb_generated.io.dart b/mobile/apps/photos/lib/src/rust/frb_generated.io.dart index 6ab196d6e2..f7b687f7d8 100644 --- a/mobile/apps/photos/lib/src/rust/frb_generated.io.dart +++ b/mobile/apps/photos/lib/src/rust/frb_generated.io.dart @@ -72,6 +72,14 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( dynamic raw); + @protected + ( + Uint64List, + List, + List + ) dco_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + dynamic raw); + @protected ( Uint64List, @@ -143,6 +151,14 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( SseDeserializer deserializer); + @protected + ( + Uint64List, + List, + List + ) sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + SseDeserializer deserializer); + @protected (Uint64List, Float32List) sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict( @@ -220,6 +236,12 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( (List, List) self, SseSerializer serializer); + @protected + void + sse_encode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + (Uint64List, List, List) self, + SseSerializer serializer); + @protected void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict( (Uint64List, Float32List) self, SseSerializer serializer); diff --git a/mobile/apps/photos/rust/src/api/usearch_api.rs b/mobile/apps/photos/rust/src/api/usearch_api.rs index 78503c9dd5..28957d7114 100644 --- a/mobile/apps/photos/rust/src/api/usearch_api.rs +++ b/mobile/apps/photos/rust/src/api/usearch_api.rs @@ -118,6 +118,38 @@ impl VectorDB { (keys, distances) } + pub fn bulk_search_keys( + &self, + potential_keys: &Vec, + count: usize, + exact: bool, + ) -> (Vec, Vec>, Vec>) { + // let max_contained_keys = potential_keys.len(); + let mut contained_keys = Vec::new(); + let mut queries = Vec::new(); + + for key in potential_keys { + let contains: bool = self.index.contains(*key); + if contains { + let embedding = self.get_vector(*key); + contained_keys.push(*key); + queries.push(embedding); + } + } + + let mut closeby_keys = Vec::new(); + let mut distances = Vec::new(); + for query in &queries { + let (keys_result, distances_result) = self.search_vectors(query, count, exact); + closeby_keys.push(keys_result); + distances.push(distances_result); + } + if contained_keys.len() != closeby_keys.len() { + panic!("The number of contained keys does not match the number of keys"); + } + (contained_keys, closeby_keys, distances) + } + /// Check if a vector with the given key exists in the index. /// `true` if the index contains the vector with the given key, `false` otherwise. pub fn contains_vector(&self, key: u64) -> bool { diff --git a/mobile/apps/photos/rust/src/frb_generated.rs b/mobile/apps/photos/rust/src/frb_generated.rs index 6b114469a7..8ccbd11c20 100644 --- a/mobile/apps/photos/rust/src/frb_generated.rs +++ b/mobile/apps/photos/rust/src/frb_generated.rs @@ -38,7 +38,7 @@ flutter_rust_bridge::frb_generated_boilerplate!( default_rust_auto_opaque = RustAutoOpaqueMoi, ); pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.11.1"; -pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1543559173; +pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1360671619; // Section: executor @@ -269,6 +269,64 @@ fn wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( }, ) } +fn wire__crate__api__usearch_api__VectorDb_bulk_search_keys_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, + rust_vec_len_: i32, + data_len_: i32, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "VectorDb_bulk_search_keys", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let message = unsafe { + flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire( + ptr_, + rust_vec_len_, + data_len_, + ) + }; + let mut deserializer = + flutter_rust_bridge::for_generated::SseDeserializer::new(message); + let api_that = , + >>::sse_decode(&mut deserializer); + let api_potential_keys = >::sse_decode(&mut deserializer); + let api_count = ::sse_decode(&mut deserializer); + let api_exact = ::sse_decode(&mut deserializer); + deserializer.end(); + move |context| { + transform_result_sse::<_, ()>((move || { + let mut api_that_guard = None; + let decode_indices_ = + flutter_rust_bridge::for_generated::lockable_compute_decode_order(vec![ + flutter_rust_bridge::for_generated::LockableOrderInfo::new( + &api_that, 0, false, + ), + ]); + for i in decode_indices_ { + match i { + 0 => api_that_guard = Some(api_that.lockable_decode_sync_ref()), + _ => unreachable!(), + } + } + let api_that_guard = api_that_guard.unwrap(); + let output_ok = + Result::<_, ()>::Ok(crate::api::usearch_api::VectorDB::bulk_search_keys( + &*api_that_guard, + &api_potential_keys, + api_count, + api_exact, + ))?; + Ok(output_ok) + })()) + } + }, + ) +} fn wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( port_: flutter_rust_bridge::for_generated::MessagePort, ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, @@ -899,6 +957,16 @@ impl SseDecode for (Vec>, Vec>) { } } +impl SseDecode for (Vec, Vec>, Vec>) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { + let mut var_field0 = >::sse_decode(deserializer); + let mut var_field1 = >>::sse_decode(deserializer); + let mut var_field2 = >>::sse_decode(deserializer); + return (var_field0, var_field1, var_field2); + } +} + impl SseDecode for (Vec, Vec) { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { @@ -990,55 +1058,61 @@ fn pde_ffi_dispatcher_primary_impl( rust_vec_len, data_len, ), - 5 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( + 5 => wire__crate__api__usearch_api__VectorDb_bulk_search_keys_impl( port, ptr, rust_vec_len, data_len, ), - 6 => wire__crate__api__usearch_api__VectorDb_contains_vector_impl( + 6 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 7 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( + 7 => wire__crate__api__usearch_api__VectorDb_contains_vector_impl( port, ptr, rust_vec_len, data_len, ), - 8 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( + 8 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( port, ptr, rust_vec_len, data_len, ), - 9 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( + 9 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( port, ptr, rust_vec_len, data_len, ), - 11 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( + 10 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( port, ptr, rust_vec_len, data_len, ), - 12 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( + 12 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( port, ptr, rust_vec_len, data_len, ), - 13 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + 13 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( port, ptr, rust_vec_len, data_len, ), - 15 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), + 14 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + port, + ptr, + rust_vec_len, + data_len, + ), + 16 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), _ => unreachable!(), } } @@ -1051,8 +1125,8 @@ fn pde_ffi_dispatcher_sync_impl( ) -> flutter_rust_bridge::for_generated::WireSyncRust2DartSse { // Codec=Pde (Serialization + dispatch), see doc to use other codecs match func_id { - 10 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), - 14 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), + 11 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), + 15 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), _ => unreachable!(), } } @@ -1171,6 +1245,15 @@ impl SseEncode for (Vec>, Vec>) { } } +impl SseEncode for (Vec, Vec>, Vec>) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) { + >::sse_encode(self.0, serializer); + >>::sse_encode(self.1, serializer); + >>::sse_encode(self.2, serializer); + } +} + impl SseEncode for (Vec, Vec) { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {