From 68d10be79eb6d4ed3dfd138087ad435d849b5077 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:36:31 +0530 Subject: [PATCH] [mob] Add new table for clip embedding --- mobile/lib/db/embeddings_db.dart | 59 +++++++++++-------- mobile/lib/models/embedding.dart | 13 ++-- .../semantic_search_service.dart | 9 +-- 3 files changed, 46 insertions(+), 35 deletions(-) diff --git a/mobile/lib/db/embeddings_db.dart b/mobile/lib/db/embeddings_db.dart index 2b10413933..1cdd331c30 100644 --- a/mobile/lib/db/embeddings_db.dart +++ b/mobile/lib/db/embeddings_db.dart @@ -6,7 +6,6 @@ import 'package:path_provider/path_provider.dart'; import "package:photos/core/event_bus.dart"; import "package:photos/events/embedding_updated_event.dart"; import "package:photos/models/embedding.dart"; -import "package:photos/models/ml/ml_versions.dart"; import "package:sqlite_async/sqlite_async.dart"; class EmbeddingsDB { @@ -15,9 +14,13 @@ class EmbeddingsDB { static final EmbeddingsDB instance = EmbeddingsDB._privateConstructor(); static const databaseName = "ente.embeddings.db"; - static const tableName = "embeddings"; + static const tableName = "clip_embedding"; + static const oldTableName = "embeddings"; static const columnFileID = "file_id"; static const columnEmbedding = "embedding"; + static const columnVersion = "version"; + + @Deprecated("") static const columnUpdationTime = "updation_time"; static Future? _dbFuture; @@ -41,8 +44,21 @@ class EmbeddingsDB { SqliteMigration( 1, (tx) async { + // Avoid creating the old table + // await tx.execute( + // 'CREATE TABLE $oldTableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))', + // ); + }, + ), + ) + ..add( + SqliteMigration( + 2, + (tx) async { + // delete old table + await tx.execute('DROP TABLE IF EXISTS $oldTableName'); await tx.execute( - 'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))', + 'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnVersion INTEGER, UNIQUE ($columnFileID))', ); }, ), @@ -57,7 +73,7 @@ class EmbeddingsDB { await db.execute('DELETE FROM $tableName'); } - Future> getAll() async { + Future> getAll() async { final db = await _database; final results = await db.getAll('SELECT * FROM $tableName'); return _convertToEmbeddings(results); @@ -66,11 +82,11 @@ class EmbeddingsDB { // Get indexed FileIDs Future> getIndexedFileIds() async { final db = await _database; - final maps = await db.getAll('SELECT $columnFileID FROM $tableName'); + final maps = await db + .getAll('SELECT $columnFileID , $columnVersion FROM $tableName'); final Map result = {}; for (final map in maps) { - result[map[columnFileID] as int] = - clipMlVersion; // TODO: Add an actual column for version + result[map[columnFileID] as int] = map[columnVersion] as int; } return result; } @@ -84,33 +100,25 @@ class EmbeddingsDB { return maps.first['count'] as int; } - Future put(Embedding embedding) async { + Future put(ClipEmbedding embedding) async { final db = await _database; await db.execute( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)', + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) VALUES (?, ?, ?)', _getRowFromEmbedding(embedding), ); Bus.instance.fire(EmbeddingUpdatedEvent()); } - Future putMany(List embeddings) async { + Future putMany(List embeddings) async { final db = await _database; final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); await db.executeBatch( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)', + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) values(?, ?, ?)', inputs, ); Bus.instance.fire(EmbeddingUpdatedEvent()); } - Future> getUnsyncedEmbeddings() async { - final db = await _database; - final results = await db.getAll( - 'SELECT * FROM $tableName WHERE $columnUpdationTime IS NULL', - ); - return _convertToEmbeddings(results); - } - Future deleteEmbeddings(List fileIDs) async { final db = await _database; await db.execute( @@ -125,8 +133,8 @@ class EmbeddingsDB { Bus.instance.fire(EmbeddingUpdatedEvent()); } - List _convertToEmbeddings(List> results) { - final List embeddings = []; + List _convertToEmbeddings(List> results) { + final List embeddings = []; for (final result in results) { final embedding = _getEmbeddingFromRow(result); if (embedding.isEmpty) continue; @@ -135,18 +143,19 @@ class EmbeddingsDB { return embeddings; } - Embedding _getEmbeddingFromRow(Map row) { + ClipEmbedding _getEmbeddingFromRow(Map row) { final fileID = row[columnFileID]; final bytes = row[columnEmbedding] as Uint8List; + final version = row[columnVersion] as int; final list = Float32List.view(bytes.buffer); - return Embedding(fileID: fileID, embedding: list); + return ClipEmbedding(fileID: fileID, embedding: list, version: version); } - List _getRowFromEmbedding(Embedding embedding) { + List _getRowFromEmbedding(ClipEmbedding embedding) { return [ embedding.fileID, Float32List.fromList(embedding.embedding).buffer.asUint8List(), - embedding.updationTime, + embedding.version, ]; } diff --git a/mobile/lib/models/embedding.dart b/mobile/lib/models/embedding.dart index 60554538b3..0fb054b3bb 100644 --- a/mobile/lib/models/embedding.dart +++ b/mobile/lib/models/embedding.dart @@ -1,22 +1,23 @@ import "dart:convert"; -class Embedding { +class ClipEmbedding { final int fileID; final List embedding; - int? updationTime; + int version; bool get isEmpty => embedding.isEmpty; - Embedding({ + ClipEmbedding({ required this.fileID, required this.embedding, - this.updationTime, + required this.version, }); - factory Embedding.empty(int fileID) { - return Embedding( + factory ClipEmbedding.empty(int fileID) { + return ClipEmbedding( fileID: fileID, embedding: [], + version: 0, ); } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 7a5ae8cb65..c01c42ae95 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -43,7 +43,7 @@ class SemanticSearchService { bool _hasInitialized = false; bool _textModelIsLoaded = false; bool _isSyncing = false; - List _cachedImageEmbeddings = []; + List _cachedImageEmbeddings = []; Future<(String, List)>? _searchScreenRequest; String? _latestPendingQuery; @@ -274,15 +274,16 @@ class SemanticSearchService { ClipResult clipResult, EnteFile entefile, ) async { - final embedding = Embedding( + final embedding = ClipEmbedding( fileID: clipResult.fileID, embedding: clipResult.embedding, + version: clipMlVersion, ); await EmbeddingsDB.instance.put(embedding); } static Future storeEmptyClipImageResult(EnteFile entefile) async { - final embedding = Embedding.empty(entefile.uploadedFileID!); + final embedding = ClipEmbedding.empty(entefile.uploadedFileID!); await EmbeddingsDB.instance.put(embedding); } @@ -363,7 +364,7 @@ class SemanticSearchService { List computeBulkSimilarities(Map args) { final queryResults = []; - final imageEmbeddings = args["imageEmbeddings"] as List; + final imageEmbeddings = args["imageEmbeddings"] as List; final textEmbedding = args["textEmbedding"] as List; final minimumSimilarity = args["minimumSimilarity"] ?? SemanticSearchService.kMinimumSimilarityThreshold;