From 832f2c451e914bb99cb3b46a6c25ccf7aa128eab Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 8 May 2025 11:47:50 +0530 Subject: [PATCH] Add bulk get method to vector db api --- mobile/lib/src/rust/api/usearch_api.dart | 2 + mobile/lib/src/rust/frb_generated.dart | 59 +++++++++++++---- mobile/rust/src/api/usearch_api.rs | 9 +++ mobile/rust/src/frb_generated.rs | 84 ++++++++++++++++++++---- 4 files changed, 130 insertions(+), 24 deletions(-) diff --git a/mobile/lib/src/rust/api/usearch_api.dart b/mobile/lib/src/rust/api/usearch_api.dart index 8260600b32..0ff5ada833 100644 --- a/mobile/lib/src/rust/api/usearch_api.dart +++ b/mobile/lib/src/rust/api/usearch_api.dart @@ -17,6 +17,8 @@ abstract class VectorDb implements RustOpaqueInterface { Future bulkAddVectors( {required Uint64List keys, required List vectors}); + Future> bulkGetVectors({required Uint64List keys}); + Future bulkRemoveVectors({required Uint64List keys}); Future<(List, List)> bulkSearchVectors( diff --git a/mobile/lib/src/rust/frb_generated.dart b/mobile/lib/src/rust/frb_generated.dart index 6a044372fd..247866b139 100644 --- a/mobile/lib/src/rust/frb_generated.dart +++ b/mobile/lib/src/rust/frb_generated.dart @@ -72,7 +72,7 @@ class RustLib extends BaseEntrypoint { String get codegenVersion => '2.9.0'; @override - int get rustContentHash => -1131116360; + int get rustContentHash => 382419186; static const kDefaultExternalLibraryLoaderConfig = ExternalLibraryLoaderConfig( @@ -93,6 +93,9 @@ abstract class RustLibApi extends BaseApi { required Uint64List keys, required List vectors}); + Future> crateApiUsearchApiVectorDbBulkGetVectors( + {required VectorDb that, required Uint64List keys}); + Future crateApiUsearchApiVectorDbBulkRemoveVectors( {required VectorDb that, required Uint64List keys}); @@ -207,7 +210,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { ); @override - Future crateApiUsearchApiVectorDbBulkRemoveVectors( + Future> crateApiUsearchApiVectorDbBulkGetVectors( {required VectorDb that, required Uint64List keys}) { return handler.executeNormal(NormalTask( callFfi: (port_) { @@ -218,6 +221,34 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 3, port: port_); }, + codec: SseCodec( + decodeSuccessData: sse_decode_list_list_prim_f_32_strict, + decodeErrorData: null, + ), + constMeta: kCrateApiUsearchApiVectorDbBulkGetVectorsConstMeta, + argValues: [that, keys], + apiImpl: this, + )); + } + + TaskConstMeta get kCrateApiUsearchApiVectorDbBulkGetVectorsConstMeta => + const TaskConstMeta( + debugName: "VectorDb_bulk_get_vectors", + argNames: ["that", "keys"], + ); + + @override + Future crateApiUsearchApiVectorDbBulkRemoveVectors( + {required VectorDb that, required Uint64List keys}) { + return handler.executeNormal(NormalTask( + callFfi: (port_) { + final serializer = SseSerializer(generalizedFrbRustBinding); + sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( + that, serializer); + sse_encode_list_prim_u_64_strict(keys, serializer); + pdeCallFfi(generalizedFrbRustBinding, serializer, + funcId: 4, port: port_); + }, codec: SseCodec( decodeSuccessData: sse_decode_usize, decodeErrorData: null, @@ -248,7 +279,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_list_prim_f_32_strict(queries, serializer); sse_encode_usize(count, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 4, port: port_); + funcId: 5, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -275,7 +306,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 5, port: port_); + funcId: 6, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -302,7 +333,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 6, port: port_); + funcId: 7, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_record_usize_usize_usize_usize_usize, @@ -330,7 +361,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 7, port: port_); + funcId: 8, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_list_prim_f_32_strict, @@ -356,7 +387,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(filePath, serializer); sse_encode_usize(dimensions, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 8)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 9)!; }, codec: SseCodec( decodeSuccessData: @@ -385,7 +416,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 9, port: port_); + funcId: 10, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_usize, @@ -411,7 +442,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 10, port: port_); + funcId: 11, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -442,7 +473,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_prim_f_32_loose(query, serializer); sse_encode_usize(count, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 11, port: port_); + funcId: 12, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -467,7 +498,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(name, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 12)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 13)!; }, codec: SseCodec( decodeSuccessData: sse_decode_String, @@ -490,7 +521,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 13, port: port_); + funcId: 14, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -1008,6 +1039,10 @@ class VectorDbImpl extends RustOpaque implements VectorDb { RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors( that: this, keys: keys, vectors: vectors); + Future> bulkGetVectors({required Uint64List keys}) => + RustLib.instance.api + .crateApiUsearchApiVectorDbBulkGetVectors(that: this, keys: keys); + Future bulkRemoveVectors({required Uint64List keys}) => RustLib.instance.api .crateApiUsearchApiVectorDbBulkRemoveVectors(that: this, keys: keys); diff --git a/mobile/rust/src/api/usearch_api.rs b/mobile/rust/src/api/usearch_api.rs index 2c32870d41..47a948cff1 100644 --- a/mobile/rust/src/api/usearch_api.rs +++ b/mobile/rust/src/api/usearch_api.rs @@ -108,6 +108,15 @@ impl VectorDB { vector } + pub fn bulk_get_vectors(&self, keys: Vec) -> Vec> { + let mut vectors = Vec::new(); + for key in keys { + let vector = self.get_vector(key); + vectors.push(vector); + } + vectors + } + pub fn remove_vector(&self, key: u64) -> usize { let removed_count = self.index.remove(key).expect("Failed to remove vector"); self.save_index(); diff --git a/mobile/rust/src/frb_generated.rs b/mobile/rust/src/frb_generated.rs index 4ec616d979..94670bc0b3 100644 --- a/mobile/rust/src/frb_generated.rs +++ b/mobile/rust/src/frb_generated.rs @@ -38,7 +38,7 @@ flutter_rust_bridge::frb_generated_boilerplate!( default_rust_auto_opaque = RustAutoOpaqueMoi, ); pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.9.0"; -pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -1131116360; +pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 382419186; // Section: executor @@ -160,6 +160,60 @@ fn wire__crate__api__usearch_api__VectorDb_bulk_add_vectors_impl( }, ) } +fn wire__crate__api__usearch_api__VectorDb_bulk_get_vectors_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, + rust_vec_len_: i32, + data_len_: i32, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "VectorDb_bulk_get_vectors", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let message = unsafe { + flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire( + ptr_, + rust_vec_len_, + data_len_, + ) + }; + let mut deserializer = + flutter_rust_bridge::for_generated::SseDeserializer::new(message); + let api_that = , + >>::sse_decode(&mut deserializer); + let api_keys = >::sse_decode(&mut deserializer); + deserializer.end(); + move |context| { + transform_result_sse::<_, ()>((move || { + let mut api_that_guard = None; + let decode_indices_ = + flutter_rust_bridge::for_generated::lockable_compute_decode_order(vec![ + flutter_rust_bridge::for_generated::LockableOrderInfo::new( + &api_that, 0, false, + ), + ]); + for i in decode_indices_ { + match i { + 0 => api_that_guard = Some(api_that.lockable_decode_sync_ref()), + _ => unreachable!(), + } + } + let api_that_guard = api_that_guard.unwrap(); + let output_ok = + Result::<_, ()>::Ok(crate::api::usearch_api::VectorDB::bulk_get_vectors( + &*api_that_guard, + api_keys, + ))?; + Ok(output_ok) + })()) + } + }, + ) +} fn wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( port_: flutter_rust_bridge::for_generated::MessagePort, ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, @@ -862,55 +916,61 @@ fn pde_ffi_dispatcher_primary_impl( rust_vec_len, data_len, ), - 3 => wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( + 3 => wire__crate__api__usearch_api__VectorDb_bulk_get_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 4 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( + 4 => wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 5 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( + 5 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 6 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( + 6 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( port, ptr, rust_vec_len, data_len, ), - 7 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( + 7 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( port, ptr, rust_vec_len, data_len, ), - 9 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( + 8 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( port, ptr, rust_vec_len, data_len, ), - 10 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( + 10 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( port, ptr, rust_vec_len, data_len, ), - 11 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + 11 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( port, ptr, rust_vec_len, data_len, ), - 13 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), + 12 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + port, + ptr, + rust_vec_len, + data_len, + ), + 14 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), _ => unreachable!(), } } @@ -923,8 +983,8 @@ fn pde_ffi_dispatcher_sync_impl( ) -> flutter_rust_bridge::for_generated::WireSyncRust2DartSse { // Codec=Pde (Serialization + dispatch), see doc to use other codecs match func_id { - 8 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), - 12 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), + 9 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), + 13 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), _ => unreachable!(), } }