[mob] Add new table for clip embedding

This commit is contained in:
Neeraj Gupta
2024-07-22 11:36:31 +05:30
parent 32912be9c2
commit 68d10be79e
3 changed files with 46 additions and 35 deletions

View File

@@ -6,7 +6,6 @@ import 'package:path_provider/path_provider.dart';
import "package:photos/core/event_bus.dart";
import "package:photos/events/embedding_updated_event.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:sqlite_async/sqlite_async.dart";
class EmbeddingsDB {
@@ -15,9 +14,13 @@ class EmbeddingsDB {
static final EmbeddingsDB instance = EmbeddingsDB._privateConstructor();
static const databaseName = "ente.embeddings.db";
static const tableName = "embeddings";
static const tableName = "clip_embedding";
static const oldTableName = "embeddings";
static const columnFileID = "file_id";
static const columnEmbedding = "embedding";
static const columnVersion = "version";
@Deprecated("")
static const columnUpdationTime = "updation_time";
static Future<SqliteDatabase>? _dbFuture;
@@ -41,8 +44,21 @@ class EmbeddingsDB {
SqliteMigration(
1,
(tx) async {
// Avoid creating the old table
// await tx.execute(
// 'CREATE TABLE $oldTableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
// );
},
),
)
..add(
SqliteMigration(
2,
(tx) async {
// delete old table
await tx.execute('DROP TABLE IF EXISTS $oldTableName');
await tx.execute(
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnVersion INTEGER, UNIQUE ($columnFileID))',
);
},
),
@@ -57,7 +73,7 @@ class EmbeddingsDB {
await db.execute('DELETE FROM $tableName');
}
Future<List<Embedding>> getAll() async {
Future<List<ClipEmbedding>> getAll() async {
final db = await _database;
final results = await db.getAll('SELECT * FROM $tableName');
return _convertToEmbeddings(results);
@@ -66,11 +82,11 @@ class EmbeddingsDB {
/// Returns a map from file ID to the version stored alongside that file's
/// embedding row.
///
/// Reads only the file ID and version columns, so embedding blobs are not
/// loaded. The version comes from the table rather than the former hard-coded
/// `clipMlVersion` placeholder.
Future<Map<int, int>> getIndexedFileIds() async {
  final db = await _database;
  final rows = await db.getAll(
    'SELECT $columnFileID , $columnVersion FROM $tableName',
  );
  return {
    for (final row in rows) row[columnFileID] as int: row[columnVersion] as int,
  };
}
@@ -84,33 +100,25 @@ class EmbeddingsDB {
return maps.first['count'] as int;
}
/// Inserts [embedding] into the clip embedding table and fires an
/// [EmbeddingUpdatedEvent].
///
/// Uses `INSERT OR REPLACE`, so a row that conflicts on the table's unique
/// file ID constraint is overwritten. The statement binds exactly three
/// values (file ID, embedding blob, version), matching its three columns.
Future<void> put(ClipEmbedding embedding) async {
  final db = await _database;
  await db.execute(
    'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) VALUES (?, ?, ?)',
    _getRowFromEmbedding(embedding),
  );
  Bus.instance.fire(EmbeddingUpdatedEvent());
}
/// Inserts every entry of [embeddings] in one batch and then fires a single
/// [EmbeddingUpdatedEvent].
///
/// Each entry becomes one parameter set of three values (file ID, embedding
/// blob, version). Rows that conflict on the unique file ID are replaced.
Future<void> putMany(List<ClipEmbedding> embeddings) async {
  final db = await _database;
  final rows = [
    for (final embedding in embeddings) _getRowFromEmbedding(embedding),
  ];
  await db.executeBatch(
    'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) values(?, ?, ?)',
    rows,
  );
  Bus.instance.fire(EmbeddingUpdatedEvent());
}
// Returns the rows whose updation-time column is NULL, converted through
// _convertToEmbeddings.
//
// NOTE(review): this method appears only on the removed side of this diff
// hunk. The table created by migration 2 has a version column and no
// updation_time column, so this query presumably cannot run against the new
// schema — confirm no remaining callers before relying on it.
Future<List<Embedding>> getUnsyncedEmbeddings() async {
final db = await _database;
final results = await db.getAll(
'SELECT * FROM $tableName WHERE $columnUpdationTime IS NULL',
);
return _convertToEmbeddings(results);
}
Future<void> deleteEmbeddings(List<int> fileIDs) async {
final db = await _database;
await db.execute(
@@ -125,8 +133,8 @@ class EmbeddingsDB {
Bus.instance.fire(EmbeddingUpdatedEvent());
}
List<Embedding> _convertToEmbeddings(List<Map<String, dynamic>> results) {
final List<Embedding> embeddings = [];
List<ClipEmbedding> _convertToEmbeddings(List<Map<String, dynamic>> results) {
final List<ClipEmbedding> embeddings = [];
for (final result in results) {
final embedding = _getEmbeddingFromRow(result);
if (embedding.isEmpty) continue;
@@ -135,18 +143,19 @@ class EmbeddingsDB {
return embeddings;
}
/// Builds a [ClipEmbedding] from a database [row].
///
/// The embedding column holds the raw bytes of a `Float32List`; they are
/// reinterpreted as floats here rather than parsed.
ClipEmbedding _getEmbeddingFromRow(Map<String, dynamic> row) {
  final fileID = row[columnFileID] as int;
  final blob = row[columnEmbedding] as Uint8List;
  final version = row[columnVersion] as int;
  // A Uint8List may be a window into a larger buffer. Viewing `blob.buffer`
  // from offset 0 would then decode bytes that do not belong to this row, so
  // the view is restricted to the blob's own offset and length. Float32 views
  // need a 4-byte-aligned offset; a misaligned blob is copied first.
  final aligned = blob.offsetInBytes % Float32List.bytesPerElement == 0
      ? blob
      : Uint8List.fromList(blob);
  final values = Float32List.view(
    aligned.buffer,
    aligned.offsetInBytes,
    aligned.lengthInBytes ~/ Float32List.bytesPerElement,
  );
  return ClipEmbedding(fileID: fileID, embedding: values, version: version);
}
/// Converts [embedding] into the positional parameter list used by the insert
/// statements: file ID, embedding bytes, version.
///
/// The list has exactly three entries, one per placeholder in those
/// statements. The embedding values are stored as the raw bytes of a
/// `Float32List`.
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
  final blob = Float32List.fromList(embedding.embedding).buffer.asUint8List();
  return [
    embedding.fileID,
    blob,
    embedding.version,
  ];
}

View File

@@ -1,22 +1,23 @@
import "dart:convert";
class Embedding {
class ClipEmbedding {
final int fileID;
final List<double> embedding;
int? updationTime;
int version;
bool get isEmpty => embedding.isEmpty;
Embedding({
ClipEmbedding({
required this.fileID,
required this.embedding,
this.updationTime,
required this.version,
});
factory Embedding.empty(int fileID) {
return Embedding(
factory ClipEmbedding.empty(int fileID) {
return ClipEmbedding(
fileID: fileID,
embedding: <double>[],
version: 0,
);
}

View File

@@ -43,7 +43,7 @@ class SemanticSearchService {
bool _hasInitialized = false;
bool _textModelIsLoaded = false;
bool _isSyncing = false;
List<Embedding> _cachedImageEmbeddings = <Embedding>[];
List<ClipEmbedding> _cachedImageEmbeddings = <ClipEmbedding>[];
Future<(String, List<EnteFile>)>? _searchScreenRequest;
String? _latestPendingQuery;
@@ -274,15 +274,16 @@ class SemanticSearchService {
ClipResult clipResult,
EnteFile entefile,
) async {
final embedding = Embedding(
final embedding = ClipEmbedding(
fileID: clipResult.fileID,
embedding: clipResult.embedding,
version: clipMlVersion,
);
await EmbeddingsDB.instance.put(embedding);
}
/// Stores an empty embedding for [entefile] in the embeddings database.
///
/// Throws if [entefile] has no `uploadedFileID`, because the value is
/// null-asserted.
static Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
  // NOTE(review): ClipEmbedding.empty records version 0. Confirm that
  // consumers of the stored version do not treat 0 as "needs re-indexing".
  final placeholder = ClipEmbedding.empty(entefile.uploadedFileID!);
  await EmbeddingsDB.instance.put(placeholder);
}
@@ -363,7 +364,7 @@ class SemanticSearchService {
List<QueryResult> computeBulkSimilarities(Map args) {
final queryResults = <QueryResult>[];
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
final imageEmbeddings = args["imageEmbeddings"] as List<ClipEmbedding>;
final textEmbedding = args["textEmbedding"] as List<double>;
final minimumSimilarity = args["minimumSimilarity"] ??
SemanticSearchService.kMinimumSimilarityThreshold;