Give distances in bulk search
This commit is contained in:
@@ -17,7 +17,7 @@ abstract class VectorDb implements RustOpaqueInterface {
|
||||
Future<void> bulkAddVectors(
|
||||
{required Uint64List keys, required List<Float32List> vectors});
|
||||
|
||||
Future<List<Uint64List>> bulkSearchVectors(
|
||||
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
|
||||
{required List<Float32List> queries, required BigInt count});
|
||||
|
||||
Future<void> deleteIndex();
|
||||
|
||||
@@ -93,10 +93,11 @@ abstract class RustLibApi extends BaseApi {
|
||||
required Uint64List keys,
|
||||
required List<Float32List> vectors});
|
||||
|
||||
Future<List<Uint64List>> crateApiUsearchApiVectorDbBulkSearchVectors(
|
||||
{required VectorDb that,
|
||||
required List<Float32List> queries,
|
||||
required BigInt count});
|
||||
Future<(List<Uint64List>, List<Float32List>)>
|
||||
crateApiUsearchApiVectorDbBulkSearchVectors(
|
||||
{required VectorDb that,
|
||||
required List<Float32List> queries,
|
||||
required BigInt count});
|
||||
|
||||
Future<void> crateApiUsearchApiVectorDbDeleteIndex({required VectorDb that});
|
||||
|
||||
@@ -203,10 +204,11 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
);
|
||||
|
||||
@override
|
||||
Future<List<Uint64List>> crateApiUsearchApiVectorDbBulkSearchVectors(
|
||||
{required VectorDb that,
|
||||
required List<Float32List> queries,
|
||||
required BigInt count}) {
|
||||
Future<(List<Uint64List>, List<Float32List>)>
|
||||
crateApiUsearchApiVectorDbBulkSearchVectors(
|
||||
{required VectorDb that,
|
||||
required List<Float32List> queries,
|
||||
required BigInt count}) {
|
||||
return handler.executeNormal(NormalTask(
|
||||
callFfi: (port_) {
|
||||
final serializer = SseSerializer(generalizedFrbRustBinding);
|
||||
@@ -218,7 +220,8 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
funcId: 3, port: port_);
|
||||
},
|
||||
codec: SseCodec(
|
||||
decodeSuccessData: sse_decode_list_list_prim_u_64_strict,
|
||||
decodeSuccessData:
|
||||
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict,
|
||||
decodeErrorData: null,
|
||||
),
|
||||
constMeta: kCrateApiUsearchApiVectorDbBulkSearchVectorsConstMeta,
|
||||
@@ -557,6 +560,21 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
return raw as Uint8List;
|
||||
}
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
dynamic raw) {
|
||||
// Codec=Dco (DartCObject based), see doc to use other codecs
|
||||
final arr = raw as List<dynamic>;
|
||||
if (arr.length != 2) {
|
||||
throw Exception('Expected 2 elements, got ${arr.length}');
|
||||
}
|
||||
return (
|
||||
dco_decode_list_list_prim_u_64_strict(arr[0]),
|
||||
dco_decode_list_list_prim_f_32_strict(arr[1]),
|
||||
);
|
||||
}
|
||||
|
||||
@protected
|
||||
(
|
||||
Uint64List,
|
||||
@@ -708,6 +726,16 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
return deserializer.buffer.getUint8List(len_);
|
||||
}
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
SseDeserializer deserializer) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
final var_field0 = sse_decode_list_list_prim_u_64_strict(deserializer);
|
||||
final var_field1 = sse_decode_list_list_prim_f_32_strict(deserializer);
|
||||
return (var_field0, var_field1);
|
||||
}
|
||||
|
||||
@protected
|
||||
(Uint64List, Float32List)
|
||||
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
@@ -858,6 +886,14 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
serializer.buffer.putUint8List(self);
|
||||
}
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
sse_encode_list_list_prim_u_64_strict(self.$1, serializer);
|
||||
sse_encode_list_list_prim_f_32_strict(self.$2, serializer);
|
||||
}
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
(Uint64List, Float32List) self, SseSerializer serializer) {
|
||||
@@ -941,7 +977,7 @@ class VectorDbImpl extends RustOpaque implements VectorDb {
|
||||
RustLib.instance.api.crateApiUsearchApiVectorDbBulkAddVectors(
|
||||
that: this, keys: keys, vectors: vectors);
|
||||
|
||||
Future<List<Uint64List>> bulkSearchVectors(
|
||||
Future<(List<Uint64List>, List<Float32List>)> bulkSearchVectors(
|
||||
{required List<Float32List> queries, required BigInt count}) =>
|
||||
RustLib.instance.api.crateApiUsearchApiVectorDbBulkSearchVectors(
|
||||
that: this, queries: queries, count: count);
|
||||
|
||||
@@ -64,6 +64,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
@protected
|
||||
Uint8List dco_decode_list_prim_u_8_strict(dynamic raw);
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
dynamic raw);
|
||||
|
||||
@protected
|
||||
(
|
||||
Uint64List,
|
||||
@@ -127,6 +132,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
@protected
|
||||
Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer);
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
SseDeserializer deserializer);
|
||||
|
||||
@protected
|
||||
(Uint64List, Float32List)
|
||||
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
@@ -200,6 +210,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
void sse_encode_list_prim_u_8_strict(
|
||||
Uint8List self, SseSerializer serializer);
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer);
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
(Uint64List, Float32List) self, SseSerializer serializer);
|
||||
|
||||
@@ -66,6 +66,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
@protected
|
||||
Uint8List dco_decode_list_prim_u_8_strict(dynamic raw);
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
dco_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
dynamic raw);
|
||||
|
||||
@protected
|
||||
(
|
||||
Uint64List,
|
||||
@@ -129,6 +134,11 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
@protected
|
||||
Uint8List sse_decode_list_prim_u_8_strict(SseDeserializer deserializer);
|
||||
|
||||
@protected
|
||||
(List<Uint64List>, List<Float32List>)
|
||||
sse_decode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
SseDeserializer deserializer);
|
||||
|
||||
@protected
|
||||
(Uint64List, Float32List)
|
||||
sse_decode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
@@ -202,6 +212,10 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
|
||||
void sse_encode_list_prim_u_8_strict(
|
||||
Uint8List self, SseSerializer serializer);
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_list_prim_u_64_strict_list_list_prim_f_32_strict(
|
||||
(List<Uint64List>, List<Float32List>) self, SseSerializer serializer);
|
||||
|
||||
@protected
|
||||
void sse_encode_record_list_prim_u_64_strict_list_prim_f_32_strict(
|
||||
(Uint64List, Float32List) self, SseSerializer serializer);
|
||||
|
||||
@@ -183,7 +183,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
final queries = laurensFaceIdToFloat32.values.toList();
|
||||
final count = BigInt.from(10);
|
||||
w?.reset();
|
||||
final results = await vectorDB.bulkSearchVectors(
|
||||
final (vectorKeys, distances) = await vectorDB.bulkSearchVectors(
|
||||
queries: queries,
|
||||
count: count,
|
||||
);
|
||||
@@ -192,7 +192,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
'Done with ${queries.length * queries.length} (${queries.length} x ${queries.length}}) embeddings comparisons in vector DB',
|
||||
);
|
||||
logger.info(
|
||||
'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}',
|
||||
'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first}, hundredth: ${vectorKeys[99]}',
|
||||
);
|
||||
|
||||
// Benchmarking our own vector comparisons
|
||||
@@ -267,7 +267,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
// Benchmarking the vector DB
|
||||
final count = BigInt.from(10);
|
||||
w?.reset();
|
||||
final results = await vectorDB.bulkSearchVectors(
|
||||
final (vectorKeys, distances) = await vectorDB.bulkSearchVectors(
|
||||
queries: clipFloat32,
|
||||
count: count,
|
||||
);
|
||||
@@ -276,7 +276,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
'Done with ${clipFloat32.length * clipFloat32.length} (${clipFloat32.length} x ${clipFloat32.length}}) embeddings comparisons in vector DB',
|
||||
);
|
||||
logger.info(
|
||||
'vector db results: ${results.length} results, first: ${results.first}, hundredth: ${results[99]}',
|
||||
'vector db results: ${vectorKeys.length} results, first: ${vectorKeys.first} with distances ${distances.first}, hundredth: ${vectorKeys[99]} with distances ${distances[99]}',
|
||||
);
|
||||
|
||||
// // Benchmarking our own vector comparisons
|
||||
|
||||
@@ -84,17 +84,16 @@ impl VectorDB {
|
||||
(matches.keys, matches.distances)
|
||||
}
|
||||
|
||||
pub fn bulk_search_vectors(&self, queries: &Vec<Vec<f32>>, count: usize) -> Vec<Vec<u64>> {
|
||||
pub fn bulk_search_vectors(&self, queries: &Vec<Vec<f32>>, count: usize) -> (Vec<Vec<u64>>, Vec<Vec<f32>>) {
|
||||
let mut keys = Vec::new();
|
||||
let mut distances = Vec::new();
|
||||
|
||||
for query in queries {
|
||||
let matches = self
|
||||
.index
|
||||
.search(query, count)
|
||||
.expect("Failed to search vectors");
|
||||
keys.push(matches.keys);
|
||||
let (keys_result, distances_result) = self.search_vectors(query, count);
|
||||
keys.push(keys_result);
|
||||
distances.push(distances_result);
|
||||
}
|
||||
keys
|
||||
(keys, distances)
|
||||
}
|
||||
|
||||
pub fn get_vector(&self, key: u64) -> Vec<f32> {
|
||||
|
||||
@@ -716,6 +716,15 @@ impl SseDecode for Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
impl SseDecode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
|
||||
let mut var_field0 = <Vec<Vec<u64>>>::sse_decode(deserializer);
|
||||
let mut var_field1 = <Vec<Vec<f32>>>::sse_decode(deserializer);
|
||||
return (var_field0, var_field1);
|
||||
}
|
||||
}
|
||||
|
||||
impl SseDecode for (Vec<u64>, Vec<f32>) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
|
||||
@@ -958,6 +967,14 @@ impl SseEncode for Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
impl SseEncode for (Vec<Vec<u64>>, Vec<Vec<f32>>) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
|
||||
<Vec<Vec<u64>>>::sse_encode(self.0, serializer);
|
||||
<Vec<Vec<f32>>>::sse_encode(self.1, serializer);
|
||||
}
|
||||
}
|
||||
|
||||
impl SseEncode for (Vec<u64>, Vec<f32>) {
|
||||
// Codec=Sse (Serialization based), see doc to use other codecs
|
||||
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
|
||||
|
||||
Reference in New Issue
Block a user