Add bulk get method to vector db api

This commit is contained in:
laurenspriem
2025-05-08 11:47:50 +05:30
parent 715c7c23a7
commit 832f2c451e
4 changed files with 130 additions and 24 deletions

View File

@@ -17,6 +17,8 @@ abstract class VectorDb implements RustOpaqueInterface {
Future<void> bulkAddVectors(
{required Uint64List keys, required List<Float32List> vectors});
Future<List<Float32List>> bulkGetVectors({required Uint64List keys});
Future<BigInt> bulkRemoveVectors({required Uint64List keys});
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(

View File

@@ -72,7 +72,7 @@ class RustLib extends BaseEntrypoint<RustLibApi, RustLibApiImpl, RustLibWire> {
String get codegenVersion => '2.9.0';
@override
int get rustContentHash => -1131116360;
int get rustContentHash => 382419186;
static const kDefaultExternalLibraryLoaderConfig =
ExternalLibraryLoaderConfig(
@@ -93,6 +93,9 @@ abstract class RustLibApi extends BaseApi {
required Uint64List keys,
required List<Float32List> vectors});
Future<List<Float32List>> crateApiUsearchApiVectorDbBulkGetVectors(
{required VectorDb that, required Uint64List keys});
Future<BigInt> crateApiUsearchApiVectorDbBulkRemoveVectors(
{required VectorDb that, required Uint64List keys});
@@ -207,7 +210,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
);
@override
Future<BigInt> crateApiUsearchApiVectorDbBulkRemoveVectors(
Future<List<Float32List>> crateApiUsearchApiVectorDbBulkGetVectors(
{required VectorDb that, required Uint64List keys}) {
return handler.executeNormal(NormalTask(
callFfi: (port_) {
@@ -218,6 +221,34 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 3, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_list_list_prim_f_32_strict,
decodeErrorData: null,
),
constMeta: kCrateApiUsearchApiVectorDbBulkGetVectorsConstMeta,
argValues: [that, keys],
apiImpl: this,
));
}
TaskConstMeta get kCrateApiUsearchApiVectorDbBulkGetVectorsConstMeta =>
const TaskConstMeta(
debugName: "VectorDb_bulk_get_vectors",
argNames: ["that", "keys"],
);
@override
Future<BigInt> crateApiUsearchApiVectorDbBulkRemoveVectors(
{required VectorDb that, required Uint64List keys}) {
return handler.executeNormal(NormalTask(
callFfi: (port_) {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
sse_encode_list_prim_u_64_strict(keys, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 4, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_usize,
decodeErrorData: null,
@@ -248,7 +279,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_list_list_prim_f_32_strict(queries, serializer);
sse_encode_usize(count, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 4, port: port_);
funcId: 5, port: port_);
},
codec: SseCodec(
decodeSuccessData:
@@ -275,7 +306,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 5, port: port_);
funcId: 6, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -302,7 +333,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 6, port: port_);
funcId: 7, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_record_usize_usize_usize_usize_usize,
@@ -330,7 +361,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
that, serializer);
sse_encode_u_64(key, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 7, port: port_);
funcId: 8, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_list_prim_f_32_strict,
@@ -356,7 +387,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_String(filePath, serializer);
sse_encode_usize(dimensions, serializer);
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 8)!;
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 9)!;
},
codec: SseCodec(
decodeSuccessData:
@@ -385,7 +416,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
that, serializer);
sse_encode_u_64(key, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 9, port: port_);
funcId: 10, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_usize,
@@ -411,7 +442,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 10, port: port_);
funcId: 11, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -442,7 +473,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_list_prim_f_32_loose(query, serializer);
sse_encode_usize(count, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 11, port: port_);
funcId: 12, port: port_);
},
codec: SseCodec(
decodeSuccessData:
@@ -467,7 +498,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
callFfi: () {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_String(name, serializer);
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 12)!;
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 13)!;
},
codec: SseCodec(
decodeSuccessData: sse_decode_String,
@@ -490,7 +521,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
callFfi: (port_) {
final serializer = SseSerializer(generalizedFrbRustBinding);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 13, port: port_);
funcId: 14, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -1008,6 +1039,10 @@ class VectorDbImpl extends RustOpaque implements VectorDb {
RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors(
that: this, keys: keys, vectors: vectors);
Future<List<Float32List>> bulkGetVectors({required Uint64List keys}) =>
RustLib.instance.api
.crateApiUsearchApiVectorDbBulkGetVectors(that: this, keys: keys);
Future<BigInt> bulkRemoveVectors({required Uint64List keys}) =>
RustLib.instance.api
.crateApiUsearchApiVectorDbBulkRemoveVectors(that: this, keys: keys);

View File

@@ -108,6 +108,15 @@ impl VectorDB {
vector
}
pub fn bulk_get_vectors(&self, keys: Vec<u64>) -> Vec<Vec<f32>> {
let mut vectors = Vec::new();
for key in keys {
let vector = self.get_vector(key);
vectors.push(vector);
}
vectors
}
pub fn remove_vector(&self, key: u64) -> usize {
let removed_count = self.index.remove(key).expect("Failed to remove vector");
self.save_index();

View File

@@ -38,7 +38,7 @@ flutter_rust_bridge::frb_generated_boilerplate!(
default_rust_auto_opaque = RustAutoOpaqueMoi,
);
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.9.0";
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -1131116360;
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 382419186;
// Section: executor
@@ -160,6 +160,60 @@ fn wire__crate__api__usearch_api__VectorDb_bulk_add_vectors_impl(
},
)
}
fn wire__crate__api__usearch_api__VectorDb_bulk_get_vectors_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr,
rust_vec_len_: i32,
data_len_: i32,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::SseCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "VectorDb_bulk_get_vectors",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let message = unsafe {
flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire(
ptr_,
rust_vec_len_,
data_len_,
)
};
let mut deserializer =
flutter_rust_bridge::for_generated::SseDeserializer::new(message);
let api_that = <RustOpaqueMoi<
flutter_rust_bridge::for_generated::RustAutoOpaqueInner<VectorDB>,
>>::sse_decode(&mut deserializer);
let api_keys = <Vec<u64>>::sse_decode(&mut deserializer);
deserializer.end();
move |context| {
transform_result_sse::<_, ()>((move || {
let mut api_that_guard = None;
let decode_indices_ =
flutter_rust_bridge::for_generated::lockable_compute_decode_order(vec![
flutter_rust_bridge::for_generated::LockableOrderInfo::new(
&api_that, 0, false,
),
]);
for i in decode_indices_ {
match i {
0 => api_that_guard = Some(api_that.lockable_decode_sync_ref()),
_ => unreachable!(),
}
}
let api_that_guard = api_that_guard.unwrap();
let output_ok =
Result::<_, ()>::Ok(crate::api::usearch_api::VectorDB::bulk_get_vectors(
&*api_that_guard,
api_keys,
))?;
Ok(output_ok)
})())
}
},
)
}
fn wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr,
@@ -862,55 +916,61 @@ fn pde_ffi_dispatcher_primary_impl(
rust_vec_len,
data_len,
),
3 => wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl(
3 => wire__crate__api__usearch_api__VectorDb_bulk_get_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
4 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl(
4 => wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
5 => wire__crate__api__usearch_api__VectorDb_delete_index_impl(
5 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
6 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl(
6 => wire__crate__api__usearch_api__VectorDb_delete_index_impl(
port,
ptr,
rust_vec_len,
data_len,
),
7 => wire__crate__api__usearch_api__VectorDb_get_vector_impl(
7 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl(
port,
ptr,
rust_vec_len,
data_len,
),
9 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl(
8 => wire__crate__api__usearch_api__VectorDb_get_vector_impl(
port,
ptr,
rust_vec_len,
data_len,
),
10 => wire__crate__api__usearch_api__VectorDb_reset_index_impl(
10 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl(
port,
ptr,
rust_vec_len,
data_len,
),
11 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl(
11 => wire__crate__api__usearch_api__VectorDb_reset_index_impl(
port,
ptr,
rust_vec_len,
data_len,
),
13 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len),
12 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
14 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len),
_ => unreachable!(),
}
}
@@ -923,8 +983,8 @@ fn pde_ffi_dispatcher_sync_impl(
) -> flutter_rust_bridge::for_generated::WireSyncRust2DartSse {
// Codec=Pde (Serialization + dispatch), see doc to use other codecs
match func_id {
8 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len),
12 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len),
9 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len),
13 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len),
_ => unreachable!(),
}
}