// Originally added as mobile/lib/db/ml/clip_vector_db.dart in commit
// ea1a2960bfd72e0cf7c440bb9073c5d05ef038ee (laurenspriem,
// Thu, 8 May 2025 12:08:55 +0530): "First implementation of clip vector db".

import "dart:typed_data" show Float32List;

import "package:flutter_rust_bridge/flutter_rust_bridge.dart" show Uint64List;
import "package:logging/logging.dart";
import "package:path/path.dart";
import "package:path_provider/path_provider.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/src/rust/api/usearch_api.dart";

/// App-wide access point to the vector index holding CLIP embeddings.
///
/// The class is a singleton that lazily opens one [VectorDb] and forwards
/// insert, lookup, delete and nearest-neighbour search calls to it. Every
/// public operation logs failures at `severe` level and rethrows them, so
/// callers still see the original error.
class ClipVectorDB {
  static final Logger _logger = Logger("ClipVectorDB");

  /// File name of the index, resolved inside the application documents
  /// directory when the database is opened.
  static const _databaseName = "ente.ml.vectordb.clip";

  /// Number of dimensions passed to [VectorDb] when the index is opened.
  static final BigInt _embeddingDimension = BigInt.from(512);

  /// The logger used by this class.
  static Logger get logger => _logger;

  // Singleton pattern
  ClipVectorDB._privateConstructor();
  static final instance = ClipVectorDB._privateConstructor();
  factory ClipVectorDB() => instance;

  // Only have a single app-wide reference to the database. The future (not
  // the resolved value) is cached so concurrent first accesses share one
  // initialization.
  static Future<VectorDb>? _vectorDbFuture;

  /// The shared [VectorDb], opened on first access.
  Future<VectorDb> get _vectorDB => _vectorDbFuture ??= _initVectorDB();

  /// Opens the [VectorDb] stored under [_databaseName] in the application
  /// documents directory.
  ///
  /// On failure the cached future is cleared before the error is rethrown,
  /// so a later access retries instead of replaying the same error forever.
  Future<VectorDb> _initVectorDB() async {
    try {
      final documentsDirectory = await getApplicationDocumentsDirectory();
      final databasePath = join(documentsDirectory.path, _databaseName);
      _logger.info("Opening vectorDB access: DB path $databasePath");
      final vectorDB = VectorDb(
        filePath: databasePath,
        dimensions: _embeddingDimension,
      );

      return vectorDB;
    } catch (e, s) {
      _logger.severe("Error opening vectorDB", e, s);
      _vectorDbFuture = null;
      rethrow;
    }
  }

  /// Forwards [keys] and [embeddings] to [VectorDb.bulkAddVectors].
  ///
  /// NOTE(review): presumably `keys[i]` identifies `embeddings[i]`; the
  /// pairing is decided by the bridge, not here. The `Float32List` element
  /// type is inferred -- confirm against the `bulkAddVectors` signature.
  Future<void> bulkInsertEmbeddings({
    required Uint64List keys,
    required List<Float32List> embeddings,
  }) async {
    final db = await _vectorDB;
    try {
      await db.bulkAddVectors(keys: keys, vectors: embeddings);
    } catch (e, s) {
      _logger.severe("Error bulk inserting embeddings", e, s);
      rethrow;
    }
  }

  /// Adds the single vector [embedding] to the index under [key].
  ///
  /// Despite the plural name, this inserts exactly one embedding.
  ///
  /// NOTE(review): the `double` element type is inferred -- confirm against
  /// the `addVector` signature.
  Future<void> insertEmbeddings({
    required BigInt key,
    required List<double> embedding,
  }) async {
    final db = await _vectorDB;
    try {
      await db.addVector(key: key, vector: embedding);
    } catch (e, s) {
      _logger.severe("Error inserting embedding", e, s);
      rethrow;
    }
  }

  /// Fetches the stored vectors for [fileIds] as [EmbeddingVector]s.
  ///
  /// Results are paired with ids by position: the i-th returned vector is
  /// labelled with `fileIds[i]`.
  ///
  /// NOTE(review): this assumes `bulkGetVectors` yields one vector per
  /// requested key, in request order -- confirm for keys that are missing
  /// from the index.
  Future<List<EmbeddingVector>> getVectors(List<int> fileIds) async {
    final db = await _vectorDB;
    try {
      final keys = Uint64List.fromList(fileIds);
      final vectors = await db.bulkGetVectors(keys: keys);
      return [
        for (var i = 0; i < vectors.length; i++)
          EmbeddingVector(fileID: fileIds[i], embedding: vectors[i]),
      ];
    } catch (e, s) {
      _logger.severe("Error getting embeddings", e, s);
      rethrow;
    }
  }

  /// Removes the vectors stored under [keys] and logs how many were deleted.
  Future<void> deleteEmbeddings(List<int> keys) async {
    final db = await _vectorDB;
    try {
      final deletedCount =
          await db.bulkRemoveVectors(keys: Uint64List.fromList(keys));
      _logger
          .info("Deleted $deletedCount embeddings, from ${keys.length} keys");
    } catch (e, s) {
      _logger.severe("Error bulk deleting specific embeddings", e, s);
      rethrow;
    }
  }

  /// Clears the index by calling [VectorDb.resetIndex].
  Future<void> deleteAllEmbeddings() async {
    final db = await _vectorDB;
    try {
      await db.resetIndex();
    } catch (e, s) {
      _logger.severe("Error deleting all embeddings", e, s);
      rethrow;
    }
  }

  /// Searches the index for up to [count] entries closest to [query].
  ///
  /// Returns the record produced by [VectorDb.searchVectors] unchanged.
  ///
  /// NOTE(review): presumably the record is (keys, distances) with parallel
  /// entries; the `double` element type of [query] is inferred -- confirm
  /// both against the bridge API.
  Future<(Uint64List, Float32List)> searchClosestVectors(
    List<double> query,
    int count,
  ) async {
    final db = await _vectorDB;
    try {
      return await db.searchVectors(query: query, count: BigInt.from(count));
    } catch (e, s) {
      _logger.severe("Error searching closest vectors", e, s);
      rethrow;
    }
  }

  /// Searches the index for the single entry closest to [query].
  ///
  /// Returns the first element of each list in the search result. If the
  /// search yields no entries, the `[0]` lookup fails; that error is logged
  /// and rethrown like any other.
  Future<(BigInt, double)> searchClosestVector(
    List<double> query,
  ) async {
    final db = await _vectorDB;
    try {
      final (keys, distances) =
          await db.searchVectors(query: query, count: BigInt.one);
      return (keys[0], distances[0]);
    } catch (e, s) {
      _logger.severe("Error searching closest vector", e, s);
      rethrow;
    }
  }
}