Give distances in bulk search

This commit is contained in:
laurenspriem
2025-04-09 15:31:03 +05:30
parent 77e2bb1d46
commit 8b489e9ced
7 changed files with 102 additions and 22 deletions

View File

@@ -17,7 +17,7 @@ abstract class VectorDb implements RustOpaqueInterface {
Future<void> bulkAddVectors(
{required Uint64List keys, required List<Float32List> vectors});
Future<List<Uint64List>> bulkSearchVectors(
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
{required List<Float32List> queries, required BigInt count});
Future<void> deleteIndex();

View File

@@ -93,10 +93,11 @@ abstract class RustLibApi extends BaseApi {
required Uint64List keys,
required List<Float32List> vectors});
Future<List<Uint64List>> crateApiUsearchApiVectorDbBulkSearchVectors(
{required VectorDb that,
required List<Float32List> queries,
required BigInt count});
Future<(List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchVectors(
{required VectorDb that,
required List<Float32List> queries,
required BigInt count});
Future<void> crateApiUsearchApiVectorDbDeleteIndex({required VectorDb that});
@@ -203,10 +204,11 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
);
@override
Future<List<Uint64List>> crateApiUsearchApiVectorDbBulkSearchVectors(
{required VectorDb that,
required List<Float32List> queries,
required BigInt count}) {
Future<(List<Uint64List>, List<Float32List>)>
crateApiUsearchApiVectorDbBulkSearchVectors(
{required VectorDb that,
required List<Float32List> queries,
required BigInt count}) {
return handler.executeNormal(NormalTask(
callFfi: (port_) {
final serializer = SseSerializer(generalizedFrbRustBinding);
@@ -218,7 +220,8 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
funcId: 3, port: port_);
},
codec: SseCodec(
decodeSuccessData: sse_decode_list_list_prim_u_64_strict,
decodeSuccessData:
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict,
decodeErrorData: null,
),
constMeta: kCrateApiUsearchApiVectorDbBulkSearchVectorsConstMeta,
@@ -557,6 +560,21 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
return raw as Uint8List;
}
@protected
(List<Uint64List>, List<Float32List>)
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw) {
// Codec=Dco (DartCObject based), see doc to use other codecs
final arr = raw as List<dynamic>;
if (arr.length != 2) {
throw Exception('Expected 2 elements, got ${arr.length}');
}
return (
dco_decode_list_list_prim_u_64_strict(arr[0]),
dco_decode_list_list_prim_f_32_strict(arr[1]),
);
}
@protected
(
Uint64List,
@@ -708,6 +726,16 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
return deserializer.buffer.getUint8List(len_);
}
@protected
(List<Uint64List>, List<Float32List>)
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
final var_field0 = sse_decode_list_list_prim_u_64_strict(deserializer);
final var_field1 = sse_decode_list_list_prim_f_32_strict(deserializer);
return (var_field0, var_field1);
}
@protected
(Uint64List, Float32List)
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
@@ -858,6 +886,14 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
serializer.buffer.putUint8List(self);
}
@protected
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
sse_encode_list_list_prim_u_64_strict(self.$1, serializer);
sse_encode_list_list_prim_f_32_strict(self.$2, serializer);
}
@protected
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
(Uint64List, Float32List) self, SseSerializer serializer) {
@@ -941,7 +977,7 @@ class VectorDbImpl extends RustOpaque implements VectorDb {
RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors(
that: this, keys: keys, vectors: vectors);
Future<List<Uint64List>> bulkSearchVectors(
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
{required List<Float32List> queries, required BigInt count}) =>
RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchVectors(
that: this, queries: queries, count: count);

View File

@@ -64,6 +64,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
Uint8List dco_decode_list_prim_u_8_strict(dynamic raw);
@protected
(List<Uint64List>, List<Float32List>)
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw);
@protected
(
Uint64List,
@@ -127,6 +132,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer);
@protected
(List<Uint64List>, List<Float32List>)
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer);
@protected
(Uint64List, Float32List)
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
@@ -200,6 +210,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
void sse_encode_list_prim_u_8_strict(
Uint8List self, SseSerializer serializer);
@protected
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer);
@protected
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
(Uint64List, Float32List) self, SseSerializer serializer);

View File

@@ -66,6 +66,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
Uint8List dco_decode_list_prim_u_8_strict(dynamic raw);
@protected
(List<Uint64List>, List<Float32List>)
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
dynamic raw);
@protected
(
Uint64List,
@@ -129,6 +134,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer);
@protected
(List<Uint64List>, List<Float32List>)
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
SseDeserializer deserializer);
@protected
(Uint64List, Float32List)
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
@@ -202,6 +212,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
void sse_encode_list_prim_u_8_strict(
Uint8List self, SseSerializer serializer);
@protected
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer);
@protected
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
(Uint64List, Float32List) self, SseSerializer serializer);

View File

@@ -183,7 +183,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
final queries = laurensFaceIdToFloat32.values.toList();
final count = BigInt.from(10);
w?.reset();
final results = await vectorDB.bulkSearchVectors(
final (vectorKeys, distances) = await vectorDB.bulkSearchVectors(
queries: queries,
count: count,
);
@@ -192,7 +192,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
'Done with ${queries.length * queries.length} (${queries.length} x ${queries.length}}) embeddings comparisons in vector DB',
);
logger.info(
'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}',
'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first}, hundredth: ${vectorKeys[99]}',
);
// Benchmarking our own vector comparisons
@@ -267,7 +267,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
// Benchmarking the vector DB
final count = BigInt.from(10);
w?.reset();
final results = await vectorDB.bulkSearchVectors(
final (vectorKeys, distances) = await vectorDB.bulkSearchVectors(
queries: clipFloat32,
count: count,
);
@@ -276,7 +276,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
'Done with ${clipFloat32.length * clipFloat32.length} (${clipFloat32.length} x ${clipFloat32.length}}) embeddings comparisons in vector DB',
);
logger.info(
'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}',
'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first} with distances ${distances.first}, hundredth: ${vectorKeys[99]} with distances ${distances[99]}',
);
// // Benchmarking our own vector comparisons

View File

@@ -84,17 +84,16 @@ impl VectorDB {
(matches.keys, matches.distances)
}
pub fn bulk_search_vectors(&self, queries: &Vec<Vec<f32>>, count: usize) -> Vec<Vec<u64>> {
pub fn bulk_search_vectors(&self, queries: &Vec<Vec<f32>>, count: usize) -> (Vec<Vec<u64>>, Vec<Vec<f32>>) {
let mut keys = Vec::new();
let mut distances = Vec::new();
for query in queries {
let matches = self
.index
.search(query, count)
.expect("Failed to search vectors");
keys.push(matches.keys);
let (keys_result, distances_result) = self.search_vectors(query, count);
keys.push(keys_result);
distances.push(distances_result);
}
keys
(keys, distances)
}
pub fn get_vector(&self, key: u64) -> Vec<f32> {

View File

@@ -716,6 +716,15 @@ impl SseDecode for Vec<u8> {
}
}
impl SseDecode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
let mut var_field0 = <Vec<Vec<u64>>>::sse_decode(deserializer);
let mut var_field1 = <Vec<Vec<f32>>>::sse_decode(deserializer);
return (var_field0, var_field1);
}
}
impl SseDecode for (Vec<u64>, Vec<f32>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
@@ -958,6 +967,14 @@ impl SseEncode for Vec<u8> {
}
}
impl SseEncode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
<Vec<Vec<u64>>>::sse_encode(self.0, serializer);
<Vec<Vec<f32>>>::sse_encode(self.1, serializer);
}
}
impl SseEncode for (Vec<u64>, Vec<f32>) {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {