Rust api to do entire search with potential keys in rust

This commit is contained in:
laurenspriem
2025-08-12 11:36:38 +05:30
parent dbf6f6aa37
commit 856a87f01c
5 changed files with 259 additions and 24 deletions

View File

@@ -21,6 +21,11 @@ abstract class VectorDb implements RustOpaqueInterface {
Future<BigInt> bulkRemoveVectors({required Uint64List keys});
Future<(Uint64List, List<Uint64List>, List<Float32List>)> bulkSearchKeys(
{required Uint64List potentialKeys,
required BigInt count,
required bool exact});
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
{required List<Float32List> queries,
required BigInt count,

View File

@@ -74,7 +74,7 @@ class RustLib extends BaseEntrypoint<RustLibApi, RustLibApiImpl, RustLibWire> {
String get codegenVersion => '2.11.1';
@override
int get rustContentHash => 1543559173;
int get rustContentHash => 1360671619;
static const kDefaultExternalLibraryLoaderConfig =
ExternalLibraryLoaderConfig(
@@ -101,6 +101,13 @@ abstract class RustLibApi extends BaseApi {
Future<BigInt> crateApiUsearchApiVectorDbBulkRemoveVectors(
{required VectorDb that, required Uint64List keys});
Future<(Uint64List, List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchKeys(
{required VectorDb that,
required Uint64List potentialKeys,
required BigInt count,
required bool exact});
Future<(List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchVectors(
{required VectorDb that,
@@ -272,6 +279,41 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
argNames: ["that", "keys"],
);
@override
Future<(Uint64List, List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchKeys(
{required VectorDb that,
required Uint64List potentialKeys,
required BigInt count,
required bool exact}) {
return handler.executeNormal(NormalTask(
callFfi: (port_) {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
sse_encode_list_prim_u_64_strict(potentialKeys, serializer);
sse_encode_usize(count, serializer);
sse_encode_bool(exact, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 5, port: port_);
},
codec: SseCodec(
decodeSuccessData:
sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict,
decodeErrorData: null,
),
constMeta: kCrateApiUsearchApiVectorDbBulkSearchKeysConstMeta,
argValues: [that, potentialKeys, count, exact],
apiImpl: this,
));
}
TaskConstMeta get kCrateApiUsearchApiVectorDbBulkSearchKeysConstMeta =>
const TaskConstMeta(
debugName: "VectorDb_bulk_search_keys",
argNames: ["that", "potentialKeys", "count", "exact"],
);
@override
Future<(List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchVectors(
@@ -288,7 +330,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_usize(count, serializer);
sse_encode_bool(exact, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 5, port: port_);
funcId: 6, port: port_);
},
codec: SseCodec(
decodeSuccessData:
@@ -317,7 +359,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
that, serializer);
sse_encode_u_64(key, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 6, port: port_);
funcId: 7, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_bool,
@@ -343,7 +385,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 7, port: port_);
funcId: 8, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -370,7 +412,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 8, port: port_);
funcId: 9, port: port_);
},
codec: SseCodec(
decodeSuccessData:
@@ -399,7 +441,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
that, serializer);
sse_encode_u_64(key, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 9, port: port_);
funcId: 10, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_list_prim_f_32_strict,
@@ -425,7 +467,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_String(filePath, serializer);
sse_encode_usize(dimensions, serializer);
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 10)!;
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 11)!;
},
codec: SseCodec(
decodeSuccessData:
@@ -454,7 +496,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
that, serializer);
sse_encode_u_64(key, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 11, port: port_);
funcId: 12, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_usize,
@@ -480,7 +522,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_Auto_Ref_RustOpaque_flutter_rust_bridgefor_generatedRustAutoOpaqueInnerVectorDB(
that, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 12, port: port_);
funcId: 13, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -513,7 +555,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_usize(count, serializer);
sse_encode_bool(exact, serializer);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 13, port: port_);
funcId: 14, port: port_);
},
codec: SseCodec(
decodeSuccessData:
@@ -538,7 +580,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
callFfi: () {
final serializer = SseSerializer(generalizedFrbRustBinding);
sse_encode_String(name, serializer);
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 14)!;
return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 15)!;
},
codec: SseCodec(
decodeSuccessData: sse_decode_String,
@@ -561,7 +603,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
callFfi: (port_) {
final serializer = SseSerializer(generalizedFrbRustBinding);
pdeCallFfi(generalizedFrbRustBinding, serializer,
funcId: 15, port: port_);
funcId: 16, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_unit,
@@ -683,6 +725,25 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
);
}
@protected
(
Uint64List,
List<Uint64List>,
List<Float32List>
) dco_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw) {
// Codec=Dco (DartCObject based), see doc to use other codecs
final arr = raw as List<dynamic>;
if (arr.length != 3) {
throw Exception('Expected 3 elements, got ${arr.length}');
}
return (
dco_decode_list_prim_u_64_strict(arr[0]),
dco_decode_list_list_prim_u_64_strict(arr[1]),
dco_decode_list_list_prim_f_32_strict(arr[2]),
);
}
@protected
(
Uint64List,
@@ -852,6 +913,20 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
return (var_field0, var_field1);
}
@protected
(
Uint64List,
List<Uint64List>,
List<Float32List>
) sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
final var_field0 = sse_decode_list_prim_u_64_strict(deserializer);
final var_field1 = sse_decode_list_list_prim_u_64_strict(deserializer);
final var_field2 = sse_decode_list_list_prim_f_32_strict(deserializer);
return (var_field0, var_field1, var_field2);
}
@protected
(Uint64List, Float32List)
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
@@ -1020,6 +1095,17 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
sse_encode_list_list_prim_f_32_strict(self.$2, serializer);
}
@protected
void
sse_encode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(Uint64List, List<Uint64List>, List<Float32List>) self,
SseSerializer serializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
sse_encode_list_prim_u_64_strict(self.$1, serializer);
sse_encode_list_list_prim_u_64_strict(self.$2, serializer);
sse_encode_list_list_prim_f_32_strict(self.$3, serializer);
}
@protected
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
(Uint64List, Float32List) self, SseSerializer serializer) {
@@ -1108,6 +1194,13 @@ class VectorDbImpl extends RustOpaque implements VectorDb {
RustLib.instance.api
.crateApiUsearchApiVectorDbBulkRemoveVectors(that: this, keys: keys);
Future<(Uint64List, List<Uint64List>, List<Float32List>)> bulkSearchKeys(
{required Uint64List potentialKeys,
required BigInt count,
required bool exact}) =>
RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchKeys(
that: this, potentialKeys: potentialKeys, count: count, exact: exact);
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
{required List<Float32List> queries,
required BigInt count,

View File

@@ -72,6 +72,14 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw);
@protected
(
Uint64List,
List<Uint64List>,
List<Float32List>
) dco_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw);
@protected
(
Uint64List,
@@ -143,6 +151,14 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer);
@protected
(
Uint64List,
List<Uint64List>,
List<Float32List>
) sse_decode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer);
@protected
(Uint64List, Float32List)
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
@@ -220,6 +236,12 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer);
@protected
void
sse_encode_record_list_prim_u_64_strict_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(Uint64List, List<Uint64List>, List<Float32List>) self,
SseSerializer serializer);
@protected
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
(Uint64List, Float32List) self, SseSerializer serializer);

View File

@@ -118,6 +118,38 @@ impl VectorDB {
(keys, distances)
}
pub fn bulk_search_keys(
&self,
potential_keys: &Vec<u64>,
count: usize,
exact: bool,
) -> (Vec<u64>, Vec<Vec<u64>>, Vec<Vec<f32>>) {
// let max_contained_keys = potential_keys.len();
let mut contained_keys = Vec::new();
let mut queries = Vec::new();
for key in potential_keys {
let contains: bool = self.index.contains(*key);
if contains {
let embedding = self.get_vector(*key);
contained_keys.push(*key);
queries.push(embedding);
}
}
let mut closeby_keys = Vec::new();
let mut distances = Vec::new();
for query in &queries {
let (keys_result, distances_result) = self.search_vectors(query, count, exact);
closeby_keys.push(keys_result);
distances.push(distances_result);
}
if contained_keys.len() != closeby_keys.len() {
panic!("The number of contained keys does not match the number of keys");
}
(contained_keys, closeby_keys, distances)
}
/// Check if a vector with the given key exists in the index.
/// `true` if the index contains the vector with the given key, `false` otherwise.
pub fn contains_vector(&self, key: u64) -> bool {

View File

@@ -38,7 +38,7 @@ flutter_rust_bridge::frb_generated_boilerplate!(
default_rust_auto_opaque = RustAutoOpaqueMoi,
);
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.11.1";
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1543559173;
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1360671619;
// Section: executor
@@ -269,6 +269,64 @@ fn wire__crate__api__usearch_api__VectorDb_bulk_remove_vectors_impl(
},
)
}
fn wire__crate__api__usearch_api__VectorDb_bulk_search_keys_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr,
rust_vec_len_: i32,
data_len_: i32,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::SseCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "VectorDb_bulk_search_keys",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let message = unsafe {
flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire(
ptr_,
rust_vec_len_,
data_len_,
)
};
let mut deserializer =
flutter_rust_bridge::for_generated::SseDeserializer::new(message);
let api_that = <RustOpaqueMoi<
flutter_rust_bridge::for_generated::RustAutoOpaqueInner<VectorDB>,
>>::sse_decode(&mut deserializer);
let api_potential_keys = <Vec<u64>>::sse_decode(&mut deserializer);
let api_count = <usize>::sse_decode(&mut deserializer);
let api_exact = <bool>::sse_decode(&mut deserializer);
deserializer.end();
move |context| {
transform_result_sse::<_, ()>((move || {
let mut api_that_guard = None;
let decode_indices_ =
flutter_rust_bridge::for_generated::lockable_compute_decode_order(vec![
flutter_rust_bridge::for_generated::LockableOrderInfo::new(
&api_that, 0, false,
),
]);
for i in decode_indices_ {
match i {
0 => api_that_guard = Some(api_that.lockable_decode_sync_ref()),
_ => unreachable!(),
}
}
let api_that_guard = api_that_guard.unwrap();
let output_ok =
Result::<_, ()>::Ok(crate::api::usearch_api::VectorDB::bulk_search_keys(
&*api_that_guard,
&api_potential_keys,
api_count,
api_exact,
))?;
Ok(output_ok)
})())
}
},
)
}
fn wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr,
@@ -899,6 +957,16 @@ impl SseDecode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
}
}
impl SseDecode for (Vec<u64>, Vec<Vec<u64>>, Vec<Vec<f32>>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
let mut var_field0 = <Vec<u64>>::sse_decode(deserializer);
let mut var_field1 = <Vec<Vec<u64>>>::sse_decode(deserializer);
let mut var_field2 = <Vec<Vec<f32>>>::sse_decode(deserializer);
return (var_field0, var_field1, var_field2);
}
}
impl SseDecode for (Vec<u64>, Vec<f32>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
@@ -990,55 +1058,61 @@ fn pde_ffi_dispatcher_primary_impl(
rust_vec_len,
data_len,
),
5 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl(
5 => wire__crate__api__usearch_api__VectorDb_bulk_search_keys_impl(
port,
ptr,
rust_vec_len,
data_len,
),
6 => wire__crate__api__usearch_api__VectorDb_contains_vector_impl(
6 => wire__crate__api__usearch_api__VectorDb_bulk_search_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
7 => wire__crate__api__usearch_api__VectorDb_delete_index_impl(
7 => wire__crate__api__usearch_api__VectorDb_contains_vector_impl(
port,
ptr,
rust_vec_len,
data_len,
),
8 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl(
8 => wire__crate__api__usearch_api__VectorDb_delete_index_impl(
port,
ptr,
rust_vec_len,
data_len,
),
9 => wire__crate__api__usearch_api__VectorDb_get_vector_impl(
9 => wire__crate__api__usearch_api__VectorDb_get_index_stats_impl(
port,
ptr,
rust_vec_len,
data_len,
),
11 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl(
10 => wire__crate__api__usearch_api__VectorDb_get_vector_impl(
port,
ptr,
rust_vec_len,
data_len,
),
12 => wire__crate__api__usearch_api__VectorDb_reset_index_impl(
12 => wire__crate__api__usearch_api__VectorDb_remove_vector_impl(
port,
ptr,
rust_vec_len,
data_len,
),
13 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl(
13 => wire__crate__api__usearch_api__VectorDb_reset_index_impl(
port,
ptr,
rust_vec_len,
data_len,
),
15 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len),
14 => wire__crate__api__usearch_api__VectorDb_search_vectors_impl(
port,
ptr,
rust_vec_len,
data_len,
),
16 => wire__crate__api__simple__init_app_impl(port, ptr, rust_vec_len, data_len),
_ => unreachable!(),
}
}
@@ -1051,8 +1125,8 @@ fn pde_ffi_dispatcher_sync_impl(
) -> flutter_rust_bridge::for_generated::WireSyncRust2DartSse {
// Codec=Pde (Serialization + dispatch), see doc to use other codecs
match func_id {
10 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len),
14 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len),
11 => wire__crate__api__usearch_api__VectorDb_new_impl(ptr, rust_vec_len, data_len),
15 => wire__crate__api__simple__greet_impl(ptr, rust_vec_len, data_len),
_ => unreachable!(),
}
}
@@ -1171,6 +1245,15 @@ impl SseEncode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
}
}
impl SseEncode for (Vec<u64>, Vec<Vec<u64>>, Vec<Vec<f32>>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
<Vec<u64>>::sse_encode(self.0, serializer);
<Vec<Vec<u64>>>::sse_encode(self.1, serializer);
<Vec<Vec<f32>>>::sse_encode(self.2, serializer);
}
}
impl SseEncode for (Vec<u64>, Vec<f32>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {