diff --git a/mobile/lib/src/rust/api/usearch_api.dart b/mobile/lib/src/rust/api/usearch_api.dart index 7e3fa3d780..7630098f59 100644 --- a/mobile/lib/src/rust/api/usearch_api.dart +++ b/mobile/lib/src/rust/api/usearch_api.dart @@ -17,7 +17,7 @@ abstract class VectorDb implements RustOpaqueInterface { Future bulkAddVectors( {required Uint64List keys, required List vectors}); - Future> bulkSearchVectors( + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count}); Future deleteIndex(); diff --git a/mobile/lib/src/rust/frb_generated.dart b/mobile/lib/src/rust/frb_generated.dart index bf3551a6ee..4e2ee761ca 100644 --- a/mobile/lib/src/rust/frb_generated.dart +++ b/mobile/lib/src/rust/frb_generated.dart @@ -93,10 +93,11 @@ abstract class RustLibApi extends BaseApi { required Uint64List keys, required List vectors}); - Future> crateApiUsearchApiVectorDbBulkSearchVectors( - {required VectorDb that, - required List queries, - required BigInt count}); + Future<(List, List)> + crateApiUsearchApiVectorDbBulkSearchVectors( + {required VectorDb that, + required List queries, + required BigInt count}); Future crateApiUsearchApiVectorDbDeleteIndex({required VectorDb that}); @@ -203,10 +204,11 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { ); @override - Future> crateApiUsearchApiVectorDbBulkSearchVectors( - {required VectorDb that, - required List queries, - required BigInt count}) { + Future<(List, List)> + crateApiUsearchApiVectorDbBulkSearchVectors( + {required VectorDb that, + required List queries, + required BigInt count}) { return handler.executeNormal(NormalTask( callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); @@ -218,7 +220,8 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { funcId: 3, port: port_); }, codec: SseCodec( - decodeSuccessData: sse_decode_list_list_prim_u_64_strict, + decodeSuccessData: + sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict, decodeErrorData: null, ), constMeta: kCrateApiUsearchApiVectorDbBulkSearchVectorsConstMeta, @@ -557,6 +560,21 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { return raw as Uint8List; } + @protected + (List, List) + dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + dynamic raw) { + // Codec=Dco (DartCObject based), see doc to use other codecs + final arr = raw as List; + if (arr.length != 2) { + throw Exception('Expected 2 elements, got ${arr.length}'); + } + return ( + dco_decode_list_list_prim_u_64_strict(arr[0]), + dco_decode_list_list_prim_f_32_strict(arr[1]), + ); + } + @protected ( Uint64List, @@ -708,6 +726,16 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { return deserializer.buffer.getUint8List(len_); } + @protected + (List, List) + sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + SseDeserializer deserializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + final var_field0 = sse_decode_list_list_prim_u_64_strict(deserializer); + final var_field1 = sse_decode_list_list_prim_f_32_strict(deserializer); + return (var_field0, var_field1); + } + @protected (Uint64List, Float32List) sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict( @@ -858,6 +886,14 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { serializer.buffer.putUint8List(self); } + @protected + void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + (List, List) self, SseSerializer serializer) { + // Codec=Sse (Serialization based), see doc to use other codecs + sse_encode_list_list_prim_u_64_strict(self.$1, serializer); + sse_encode_list_list_prim_f_32_strict(self.$2, serializer); + } + @protected void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict( (Uint64List, Float32List) self, SseSerializer serializer) { @@ -941,7 +977,7 @@ class VectorDbImpl extends RustOpaque implements VectorDb { RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors( that: this, keys: keys, vectors: vectors); - Future> bulkSearchVectors( + Future<(List, List)> bulkSearchVectors( {required List queries, required BigInt count}) => RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchVectors( that: this, queries: queries, count: count); diff --git a/mobile/lib/src/rust/frb_generated.io.dart b/mobile/lib/src/rust/frb_generated.io.dart index 440796cf32..34a84063ae 100644 --- a/mobile/lib/src/rust/frb_generated.io.dart +++ b/mobile/lib/src/rust/frb_generated.io.dart @@ -64,6 +64,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { @protected Uint8List dco_decode_list_prim_u_8_strict(dynamic raw); + @protected + (List, List) + dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + dynamic raw); + @protected ( Uint64List, @@ -127,6 +132,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { @protected Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer); + @protected + (List, List) + sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + SseDeserializer deserializer); + @protected (Uint64List, Float32List) sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict( @@ -200,6 +210,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { void sse_encode_list_prim_u_8_strict( Uint8List self, SseSerializer serializer); + @protected + void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + (List, List) self, SseSerializer serializer); + @protected void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict( (Uint64List, Float32List) self, SseSerializer serializer); diff --git a/mobile/lib/src/rust/frb_generated.web.dart b/mobile/lib/src/rust/frb_generated.web.dart index e3f81919d0..20c7329814 100644 --- a/mobile/lib/src/rust/frb_generated.web.dart +++ b/mobile/lib/src/rust/frb_generated.web.dart @@ -66,6 +66,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { @protected Uint8List dco_decode_list_prim_u_8_strict(dynamic raw); + @protected + (List, List) + dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + dynamic raw); + @protected ( Uint64List, @@ -129,6 +134,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { @protected Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer); + @protected + (List, List) + sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + SseDeserializer deserializer); + @protected (Uint64List, Float32List) sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict( @@ -202,6 +212,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { void sse_encode_list_prim_u_8_strict( Uint8List self, SseSerializer serializer); + @protected + void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict( + (List, List) self, SseSerializer serializer); + @protected void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict( (Uint64List, Float32List) self, SseSerializer serializer); diff --git a/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart b/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart index acf4690af4..421ba6f57e 100644 --- a/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart @@ -183,7 +183,7 @@ class _MLDebugSectionWidgetState extends State { final queries = laurensFaceIdToFloat32.values.toList(); final count = BigInt.from(10); w?.reset(); - final results = await vectorDB.bulkSearchVectors( + final (vectorKeys, distances) = await vectorDB.bulkSearchVectors( queries: queries, count: count, ); @@ -192,7 +192,7 @@ class _MLDebugSectionWidgetState extends State { 'Done with ${queries.length * queries.length} (${queries.length} x ${queries.length}}) embeddings comparisons in vector DB', ); logger.info( - 'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}', + 'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first}, hundredth: ${vectorKeys[99]}', ); // Benchmarking our own vector comparisons @@ -267,7 +267,7 @@ class _MLDebugSectionWidgetState extends State { // Benchmarking the vector DB final count = BigInt.from(10); w?.reset(); - final results = await vectorDB.bulkSearchVectors( + final (vectorKeys, distances) = await vectorDB.bulkSearchVectors( queries: clipFloat32, count: count, ); @@ -276,7 +276,7 @@ class _MLDebugSectionWidgetState extends State { 'Done with ${clipFloat32.length * clipFloat32.length} (${clipFloat32.length} x ${clipFloat32.length}}) embeddings comparisons in vector DB', ); logger.info( - 'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}', + 'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first} with distances ${distances.first}, hundredth: ${vectorKeys[99]} with distances ${distances[99]}', ); // // Benchmarking our own vector comparisons diff --git a/mobile/rust/src/api/usearch_api.rs b/mobile/rust/src/api/usearch_api.rs index 7deb4b421c..d5d6ce318e 100644 --- a/mobile/rust/src/api/usearch_api.rs +++ b/mobile/rust/src/api/usearch_api.rs @@ -84,17 +84,16 @@ impl VectorDB { (matches.keys, matches.distances) } - pub fn bulk_search_vectors(&self, queries: &Vec>, count: usize) -> Vec> { + pub fn bulk_search_vectors(&self, queries: &Vec>, count: usize) -> (Vec>, Vec>) { let mut keys = Vec::new(); + let mut distances = Vec::new(); for query in queries { - let matches = self - .index - .search(query, count) - .expect("Failed to search vectors"); - keys.push(matches.keys); + let (keys_result, distances_result) = self.search_vectors(query, count); + keys.push(keys_result); + distances.push(distances_result); } - keys + (keys, distances) } pub fn get_vector(&self, key: u64) -> Vec { diff --git a/mobile/rust/src/frb_generated.rs b/mobile/rust/src/frb_generated.rs index d27f77f862..8e80d305db 100644 --- a/mobile/rust/src/frb_generated.rs +++ b/mobile/rust/src/frb_generated.rs @@ -716,6 +716,15 @@ impl SseDecode for Vec { } } +impl SseDecode for (Vec>, Vec>) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { + let mut var_field0 = >>::sse_decode(deserializer); + let mut var_field1 = >>::sse_decode(deserializer); + return (var_field0, var_field1); + } +} + impl SseDecode for (Vec, Vec) { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self { @@ -958,6 +967,14 @@ impl SseEncode for Vec { } } +impl SseEncode for (Vec>, Vec>) { + // Codec=Sse (Serialization based), see doc to use other codecs + fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) { + >>::sse_encode(self.0, serializer); + >>::sse_encode(self.1, serializer); + } +} + impl SseEncode for (Vec, Vec) { // Codec=Sse (Serialization based), see doc to use other codecs fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {