[mob] Add new table for clip embedding
This commit is contained in:
@@ -6,7 +6,6 @@ import 'package:path_provider/path_provider.dart';
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/events/embedding_updated_event.dart";
|
||||
import "package:photos/models/embedding.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:sqlite_async/sqlite_async.dart";
|
||||
|
||||
class EmbeddingsDB {
|
||||
@@ -15,9 +14,13 @@ class EmbeddingsDB {
|
||||
static final EmbeddingsDB instance = EmbeddingsDB._privateConstructor();
|
||||
|
||||
static const databaseName = "ente.embeddings.db";
|
||||
static const tableName = "embeddings";
|
||||
static const tableName = "clip_embedding";
|
||||
static const oldTableName = "embeddings";
|
||||
static const columnFileID = "file_id";
|
||||
static const columnEmbedding = "embedding";
|
||||
static const columnVersion = "version";
|
||||
|
||||
@Deprecated("")
|
||||
static const columnUpdationTime = "updation_time";
|
||||
|
||||
static Future<SqliteDatabase>? _dbFuture;
|
||||
@@ -41,8 +44,21 @@ class EmbeddingsDB {
|
||||
SqliteMigration(
|
||||
1,
|
||||
(tx) async {
|
||||
// Avoid creating the old table
|
||||
// await tx.execute(
|
||||
// 'CREATE TABLE $oldTableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
|
||||
// );
|
||||
},
|
||||
),
|
||||
)
|
||||
..add(
|
||||
SqliteMigration(
|
||||
2,
|
||||
(tx) async {
|
||||
// delete old table
|
||||
await tx.execute('DROP TABLE IF EXISTS $oldTableName');
|
||||
await tx.execute(
|
||||
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
|
||||
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnVersion INTEGER, UNIQUE ($columnFileID))',
|
||||
);
|
||||
},
|
||||
),
|
||||
@@ -57,7 +73,7 @@ class EmbeddingsDB {
|
||||
await db.execute('DELETE FROM $tableName');
|
||||
}
|
||||
|
||||
Future<List<Embedding>> getAll() async {
|
||||
Future<List<ClipEmbedding>> getAll() async {
|
||||
final db = await _database;
|
||||
final results = await db.getAll('SELECT * FROM $tableName');
|
||||
return _convertToEmbeddings(results);
|
||||
@@ -66,11 +82,11 @@ class EmbeddingsDB {
|
||||
// Get indexed FileIDs
|
||||
Future<Map<int, int>> getIndexedFileIds() async {
|
||||
final db = await _database;
|
||||
final maps = await db.getAll('SELECT $columnFileID FROM $tableName');
|
||||
final maps = await db
|
||||
.getAll('SELECT $columnFileID , $columnVersion FROM $tableName');
|
||||
final Map<int, int> result = {};
|
||||
for (final map in maps) {
|
||||
result[map[columnFileID] as int] =
|
||||
clipMlVersion; // TODO: Add an actual column for version
|
||||
result[map[columnFileID] as int] = map[columnVersion] as int;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -84,33 +100,25 @@ class EmbeddingsDB {
|
||||
return maps.first['count'] as int;
|
||||
}
|
||||
|
||||
Future<void> put(Embedding embedding) async {
|
||||
Future<void> put(ClipEmbedding embedding) async {
|
||||
final db = await _database;
|
||||
await db.execute(
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)',
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) VALUES (?, ?, ?)',
|
||||
_getRowFromEmbedding(embedding),
|
||||
);
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
Future<void> putMany(List<Embedding> embeddings) async {
|
||||
Future<void> putMany(List<ClipEmbedding> embeddings) async {
|
||||
final db = await _database;
|
||||
final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList();
|
||||
await db.executeBatch(
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)',
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) values(?, ?, ?)',
|
||||
inputs,
|
||||
);
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
Future<List<Embedding>> getUnsyncedEmbeddings() async {
|
||||
final db = await _database;
|
||||
final results = await db.getAll(
|
||||
'SELECT * FROM $tableName WHERE $columnUpdationTime IS NULL',
|
||||
);
|
||||
return _convertToEmbeddings(results);
|
||||
}
|
||||
|
||||
Future<void> deleteEmbeddings(List<int> fileIDs) async {
|
||||
final db = await _database;
|
||||
await db.execute(
|
||||
@@ -125,8 +133,8 @@ class EmbeddingsDB {
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
List<Embedding> _convertToEmbeddings(List<Map<String, dynamic>> results) {
|
||||
final List<Embedding> embeddings = [];
|
||||
List<ClipEmbedding> _convertToEmbeddings(List<Map<String, dynamic>> results) {
|
||||
final List<ClipEmbedding> embeddings = [];
|
||||
for (final result in results) {
|
||||
final embedding = _getEmbeddingFromRow(result);
|
||||
if (embedding.isEmpty) continue;
|
||||
@@ -135,18 +143,19 @@ class EmbeddingsDB {
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
Embedding _getEmbeddingFromRow(Map<String, dynamic> row) {
|
||||
ClipEmbedding _getEmbeddingFromRow(Map<String, dynamic> row) {
|
||||
final fileID = row[columnFileID];
|
||||
final bytes = row[columnEmbedding] as Uint8List;
|
||||
final version = row[columnVersion] as int;
|
||||
final list = Float32List.view(bytes.buffer);
|
||||
return Embedding(fileID: fileID, embedding: list);
|
||||
return ClipEmbedding(fileID: fileID, embedding: list, version: version);
|
||||
}
|
||||
|
||||
List<Object?> _getRowFromEmbedding(Embedding embedding) {
|
||||
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
|
||||
return [
|
||||
embedding.fileID,
|
||||
Float32List.fromList(embedding.embedding).buffer.asUint8List(),
|
||||
embedding.updationTime,
|
||||
embedding.version,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import "dart:convert";
|
||||
|
||||
class Embedding {
|
||||
class ClipEmbedding {
|
||||
final int fileID;
|
||||
final List<double> embedding;
|
||||
int? updationTime;
|
||||
int version;
|
||||
|
||||
bool get isEmpty => embedding.isEmpty;
|
||||
|
||||
Embedding({
|
||||
ClipEmbedding({
|
||||
required this.fileID,
|
||||
required this.embedding,
|
||||
this.updationTime,
|
||||
required this.version,
|
||||
});
|
||||
|
||||
factory Embedding.empty(int fileID) {
|
||||
return Embedding(
|
||||
factory ClipEmbedding.empty(int fileID) {
|
||||
return ClipEmbedding(
|
||||
fileID: fileID,
|
||||
embedding: <double>[],
|
||||
version: 0,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class SemanticSearchService {
|
||||
bool _hasInitialized = false;
|
||||
bool _textModelIsLoaded = false;
|
||||
bool _isSyncing = false;
|
||||
List<Embedding> _cachedImageEmbeddings = <Embedding>[];
|
||||
List<ClipEmbedding> _cachedImageEmbeddings = <ClipEmbedding>[];
|
||||
Future<(String, List<EnteFile>)>? _searchScreenRequest;
|
||||
String? _latestPendingQuery;
|
||||
|
||||
@@ -274,15 +274,16 @@ class SemanticSearchService {
|
||||
ClipResult clipResult,
|
||||
EnteFile entefile,
|
||||
) async {
|
||||
final embedding = Embedding(
|
||||
final embedding = ClipEmbedding(
|
||||
fileID: clipResult.fileID,
|
||||
embedding: clipResult.embedding,
|
||||
version: clipMlVersion,
|
||||
);
|
||||
await EmbeddingsDB.instance.put(embedding);
|
||||
}
|
||||
|
||||
static Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
|
||||
final embedding = Embedding.empty(entefile.uploadedFileID!);
|
||||
final embedding = ClipEmbedding.empty(entefile.uploadedFileID!);
|
||||
await EmbeddingsDB.instance.put(embedding);
|
||||
}
|
||||
|
||||
@@ -363,7 +364,7 @@ class SemanticSearchService {
|
||||
|
||||
List<QueryResult> computeBulkSimilarities(Map args) {
|
||||
final queryResults = <QueryResult>[];
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<ClipEmbedding>;
|
||||
final textEmbedding = args["textEmbedding"] as List<double>;
|
||||
final minimumSimilarity = args["minimumSimilarity"] ??
|
||||
SemanticSearchService.kMinimumSimilarityThreshold;
|
||||
|
||||
Reference in New Issue
Block a user