From 715c7c23a735433bdad80ae7eb3084a3115ebed7 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 8 May 2025 10:29:25 +0530 Subject: [PATCH] Add bulk remove embeddings api --- mobile/lib/src/rust/api/usearch_api.dart | 2 + mobile/lib/src/rust/frb_generated.dart | 57 ++++++++++++---- mobile/rust/src/api/usearch_api.rs | 12 ++++ mobile/rust/src/frb_generated.rs | 83 ++++++++++++++++++++---- 4 files changed, 132 insertions(+), 22 deletions(-) diff --git a/mobile/lib/src/rust/api/usearch_api.dart b/mobile/lib/src/rust/api/usearch_api.dart index 7630098f59..8260600b32 100644 --- a/mobile/lib/src/rust/api/usearch_api.dart +++ b/mobile/lib/src/rust/api/usearch_api.dart @@ -17,6 +17,8 @@ abstract class VectorDb implements RustOpaqueInterface { Future bulkAddVectors( {required Uint64List keys, required List vectors}); + Future bulkRemoveVectors({required Uint64List keys}); + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count}); diff --git a/mobile/lib/src/rust/frb_generated.dart b/mobile/lib/src/rust/frb_generated.dart index 4e2ee761ca..6a044372fd 100644 --- a/mobile/lib/src/rust/frb_generated.dart +++ b/mobile/lib/src/rust/frb_generated.dart @@ -72,7 +72,7 @@ class RustLib extends BaseEntrypoint { String get codegenVersion => '2.9.0'; @override - int get rustContentHash => -674813457; + int get rustContentHash => -1131116360; static const kDefaultExternalLibraryLoaderConfig = ExternalLibraryLoaderConfig( @@ -93,6 +93,9 @@ abstract class RustLibApi extends BaseApi { required Uint64List keys, required List vectors}); + Future crateApiUsearchApiVectorDbBulkRemoveVectors( + {required VectorDb that, required Uint64List keys}); + Future<(List, List)> crateApiUsearchApiVectorDbBulkSearchVectors( {required VectorDb that, @@ -203,6 +206,34 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { argNames: ["that", "keys", "vectors"], ); + @override + Future crateApiUsearchApiVectorDbBulkRemoveVectors( + {required VectorDb that, required Uint64List keys}) { + return handler.executeNormal(NormalTask( + callFfi: (port_) { + final serializer = SseSerializer(generalizedFrbRustBinding); + sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( + that, serializer); + sse_encode_list_prim_u_64_strict(keys, serializer); + pdeCallFfi(generalizedFrbRustBinding, serializer, + funcId: 3, port: port_); + }, + codec: SseCodec( + decodeSuccessData: sse_decode_usize, + decodeErrorData: null, + ), + constMeta: kCrateApiUsearchApiVectorDbBulkRemoveVectorsConstMeta, + argValues: [that, keys], + apiImpl: this, + )); + } + + TaskConstMeta get kCrateApiUsearchApiVectorDbBulkRemoveVectorsConstMeta => + const TaskConstMeta( + debugName: "VectorDb_bulk_remove_vectors", + argNames: ["that", "keys"], + ); + @override Future<(List, List)> crateApiUsearchApiVectorDbBulkSearchVectors( @@ -217,7 +248,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_list_prim_f_32_strict(queries, serializer); sse_encode_usize(count, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 3, port: port_); + funcId: 4, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -244,7 +275,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 4, port: port_); + funcId: 5, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -271,7 +302,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 5, port: port_); + funcId: 6, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_record_usize_usize_usize_usize_usize, @@ -299,7 +330,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 6, port: port_); + funcId: 7, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_list_prim_f_32_strict, @@ -325,7 +356,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(filePath, serializer); sse_encode_usize(dimensions, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 7)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 8)!; }, codec: SseCodec( decodeSuccessData: @@ -354,7 +385,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { that, serializer); sse_encode_u_64(key, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 8, port: port_); + funcId: 9, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_usize, @@ -380,7 +411,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB( that, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 9, port: port_); + funcId: 10, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -411,7 +442,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_prim_f_32_loose(query, serializer); sse_encode_usize(count, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 10, port: port_); + funcId: 11, port: port_); }, codec: SseCodec( decodeSuccessData: @@ -436,7 +467,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(name, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 11)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 12)!; }, codec: SseCodec( decodeSuccessData: sse_decode_String, @@ -459,7 +490,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 12, port: port_); + funcId: 13, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -977,6 +1008,10 @@ class VectorDbImpl extends RustOpaque implements VectorDb { RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors( that: this, keys: keys, vectors: vectors); + Future bulkRemoveVectors({required Uint64List keys}) => + RustLib.instance.api + .crateApiUsearchApiVectorDbBulkRemoveVectors(that: this, keys: keys); + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count}) => RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchVectors( diff --git a/mobile/rust/src/api/usearch_api.rs b/mobile/rust/src/api/usearch_api.rs index 107495dbda..2c32870d41 100644 --- a/mobile/rust/src/api/usearch_api.rs +++ b/mobile/rust/src/api/usearch_api.rs @@ -114,6 +114,18 @@ impl VectorDB { removed_count } + pub fn bulk_remove_vectors(&self, keys: Vec) -> usize { + let mut removed_count = 0; + for key in keys { + removed_count += self + .index + .remove(key) + .expect("Failed to (bulk) remove vector"); + } + self.save_index(); + removed_count + } + pub fn reset_index(&self) { self.index.reset().expect("Failed to reset index"); self.save_index(); diff --git a/mobile/rust/src/frb_generated.rs b/mobile/rust/src/frb_generated.rs index 0356753532..4ec616d979 100644 --- a/mobile/rust/src/frb_generated.rs +++ b/mobile/rust/src/frb_generated.rs @@ -38,7 +38,7 @@ flutter_rust_bridge::frb_generated_boilerplate!( default_rust_auto_opaque = RustAutoOpaqueMoi, ); pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.9.0"; -pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -674813457; +pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -1131116360; // Section: executor @@ -160,6 +160,61 @@ fn wire__crate__api__usearch_api__VectorDb_bulk_add_vectors_impl( }, ) } +fn wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, + rust_vec_len_: i32, + data_len_: i32, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "VectorDb_bulk_remove_vectors", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let message = unsafe { + flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire( + ptr_, + rust_vec_len_, + data_len_, + ) + }; + let mut deserializer = + flutter_rust_bridge::for_generated::SseDeserializer::new(message); + let api_that = , + >>::sse_decode(&mut deserializer); + let api_keys = >::sse_decode(&mut deserializer); + deserializer.end(); + move |context| { + transform_result_sse::<_, ()>((move || { + let mut api_that_guard = None; + let decode_indices_ = + flutter_rust_bridge::for_generated::lockable_compute_decode_order(vec![ + flutter_rust_bridge::for_generated::LockableOrderInfo::new( + &api_that, 0, false, + ), + ]); + for i in decode_indices_ { + match i { + 0 => api_that_guard = Some(api_that.lockable_decode_sync_ref()), + _ => unreachable!(), + } + } + let api_that_guard = api_that_guard.unwrap(); + let output_ok = Result::<_, ()>::Ok( + crate::api::usearch_api::VectorDB::bulk_remove_vectors( + &*api_that_guard, + api_keys, + ), + )?; + Ok(output_ok) + })()) + } + }, + ) +} fn wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( port_: flutter_rust_bridge::for_generated::MessagePort, ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, @@ -807,49 +862,55 @@ fn pde_ffi_dispatcher_primary_impl( rust_vec_len, data_len, ), - 3 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( + 3 => wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 4 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( + 4 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl( port, ptr, rust_vec_len, data_len, ), - 5 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( + 5 => wire__crate__api__usearch_api__VectorDb_delete_index_impl( port, ptr, rust_vec_len, data_len, ), - 6 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( + 6 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl( port, ptr, rust_vec_len, data_len, ), - 8 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( + 7 => wire__crate__api__usearch_api__VectorDb_get_vector_impl( port, ptr, rust_vec_len, data_len, ), - 9 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( + 9 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl( port, ptr, rust_vec_len, data_len, ), - 10 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + 10 => wire__crate__api__usearch_api__VectorDb_reset_index_impl( port, ptr, rust_vec_len, data_len, ), - 12 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), + 11 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl( + port, + ptr, + rust_vec_len, + data_len, + ), + 13 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len), _ => unreachable!(), } } @@ -862,8 +923,8 @@ fn pde_ffi_dispatcher_sync_impl( ) -> flutter_rust_bridge::for_generated::WireSyncRust2DartSse { // Codec=Pde (Serialization + dispatch), see doc to use other codecs match func_id { - 7 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), - 11 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), + 8 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len), + 12 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len), _ => unreachable!(), } }