diff --git a/mobile/lib/db/embeddings_db.dart b/mobile/lib/db/embeddings_db.dart index 4da5c056a5..0d98314d8c 100644 --- a/mobile/lib/db/embeddings_db.dart +++ b/mobile/lib/db/embeddings_db.dart @@ -16,7 +16,6 @@ class EmbeddingsDB { static const databaseName = "ente.embeddings.db"; static const tableName = "embeddings"; static const columnFileID = "file_id"; - static const columnModel = "model"; static const columnEmbedding = "embedding"; static const columnUpdationTime = "updation_time"; @@ -42,7 +41,7 @@ class EmbeddingsDB { 1, (tx) async { await tx.execute( - 'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnModel INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID, $columnModel))', + 'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))', ); }, ), @@ -57,19 +56,16 @@ class EmbeddingsDB { await db.execute('DELETE FROM $tableName'); } - Future> getAll(Model model) async { + Future> getAll() async { final db = await _database; final results = await db.getAll('SELECT * FROM $tableName'); return _convertToEmbeddings(results); } // Get FileIDs for a specific model - Future> getFileIDs(Model model) async { + Future> getFileIDs() async { final db = await _database; - final results = await db.getAll( - 'SELECT $columnFileID FROM $tableName WHERE $columnModel = ?', - [modelToInt(model)!], - ); + final results = await db.getAll('SELECT $columnFileID FROM $tableName'); if (results.isEmpty) { return {}; } @@ -79,7 +75,7 @@ class EmbeddingsDB { Future put(Embedding embedding) async { final db = await _database; await db.execute( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)', + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)', _getRowFromEmbedding(embedding), ); Bus.instance.fire(EmbeddingUpdatedEvent()); @@ -89,7 +85,7 @@ class EmbeddingsDB { final db = await _database; final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); await db.executeBatch( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)', + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)', inputs, ); Bus.instance.fire(EmbeddingUpdatedEvent()); @@ -111,12 +107,9 @@ class EmbeddingsDB { Bus.instance.fire(EmbeddingUpdatedEvent()); } - Future deleteAllForModel(Model model) async { + Future deleteAll() async { final db = await _database; - await db.execute( - 'DELETE FROM $tableName WHERE $columnModel = ?', - [modelToInt(model)!], - ); + await db.execute('DELETE FROM $tableName'); Bus.instance.fire(EmbeddingUpdatedEvent()); } @@ -132,16 +125,14 @@ class EmbeddingsDB { Embedding _getEmbeddingFromRow(Map row) { final fileID = row[columnFileID]; - final model = intToModel(row[columnModel])!; final bytes = row[columnEmbedding] as Uint8List; final list = Float32List.view(bytes.buffer); - return Embedding(fileID: fileID, model: model, embedding: list); + return Embedding(fileID: fileID, embedding: list); } List _getRowFromEmbedding(Embedding embedding) { return [ embedding.fileID, - modelToInt(embedding.model)!, Float32List.fromList(embedding.embedding).buffer.asUint8List(), embedding.updationTime, ]; @@ -157,26 +148,4 @@ class EmbeddingsDB { await deprecatedIsar.delete(); } } - - int? modelToInt(Model model) { - switch (model) { - case Model.onnxClip: - return 1; - case Model.ggmlClip: - return 2; - default: - return null; - } - } - - Model? intToModel(int model) { - switch (model) { - case 1: - return Model.onnxClip; - case 2: - return Model.ggmlClip; - default: - return null; - } - } } diff --git a/mobile/lib/models/embedding.dart b/mobile/lib/models/embedding.dart index 91ac9a0213..60554538b3 100644 --- a/mobile/lib/models/embedding.dart +++ b/mobile/lib/models/embedding.dart @@ -2,7 +2,6 @@ import "dart:convert"; class Embedding { final int fileID; - final Model model; final List embedding; int? updationTime; @@ -10,15 +9,13 @@ class Embedding { Embedding({ required this.fileID, - required this.model, required this.embedding, this.updationTime, }); - factory Embedding.empty(int fileID, Model model) { + factory Embedding.empty(int fileID) { return Embedding( fileID: fileID, - model: model, embedding: [], ); } @@ -31,34 +28,3 @@ class Embedding { return jsonEncode(embedding); } } - -enum Model { - onnxClip, - ggmlClip, -} - -extension ModelExtension on Model { - String get name => serialize(this); -} - -String serialize(Model model) { - switch (model) { - case Model.onnxClip: - return 'onnx-clip'; - case Model.ggmlClip: - return 'ggml-clip'; - default: - throw Exception('$model is not a valid Model'); - } -} - -Model deserialize(String model) { - switch (model) { - case 'onnx-clip': - return Model.onnxClip; - case 'ggml-clip': - return Model.ggmlClip; - default: - throw Exception('$model is not a valid Model'); - } -} diff --git a/mobile/lib/services/machine_learning/semantic_search/embedding_store.dart b/mobile/lib/services/machine_learning/semantic_search/embedding_store.dart index 00859a20ce..e2acbb6871 100644 --- a/mobile/lib/services/machine_learning/semantic_search/embedding_store.dart +++ b/mobile/lib/services/machine_learning/semantic_search/embedding_store.dart @@ -33,17 +33,16 @@ class EmbeddingStore { _preferences = await SharedPreferences.getInstance(); } - Future pullEmbeddings(Model model) async { - return true; // TODO: remove this + Future pullEmbeddings() async { if (_remoteSyncStatus != null) { return _remoteSyncStatus!.future; } _remoteSyncStatus = Completer(); try { - var remoteEmbeddings = await _getRemoteEmbeddings(model); + var remoteEmbeddings = await _getRemoteEmbeddings(); await _storeRemoteEmbeddings(remoteEmbeddings.embeddings); while (remoteEmbeddings.hasMore) { - remoteEmbeddings = await _getRemoteEmbeddings(model); + remoteEmbeddings = await _getRemoteEmbeddings(); await _storeRemoteEmbeddings(remoteEmbeddings.embeddings); } _remoteSyncStatus!.complete(true); @@ -85,8 +84,8 @@ class EmbeddingStore { unawaited(_pushEmbedding(file, embedding)); } - Future clearEmbeddings(Model model) async { - await EmbeddingsDB.instance.deleteAllForModel(model); + Future clearEmbeddings() async { + await EmbeddingsDB.instance.deleteAll(); await _preferences.remove(kEmbeddingsSyncTimeKey); } @@ -106,7 +105,6 @@ class EmbeddingStore { "/embeddings", data: { "fileID": embedding.fileID, - "model": embedding.model.name, "encryptedEmbedding": encryptedData, "decryptionHeader": header, }, @@ -119,10 +117,7 @@ class EmbeddingStore { } } - Future _getRemoteEmbeddings( - Model model, { - int limit = 200, - }) async { + Future _getRemoteEmbeddings({int limit = 200}) async { final remoteEmbeddings = []; try { final sinceTime = _preferences.getInt(kEmbeddingsSyncTimeKey) ?? 0; @@ -130,7 +125,6 @@ class EmbeddingStore { final response = await _dio.get( "/embeddings/diff", queryParameters: { - "model": model.name, "sinceTime": sinceTime, "limit": limit, }, @@ -212,7 +206,6 @@ Future> _decodeEmbeddings(Map args) async { embeddings.add( Embedding( fileID: input.embedding.fileID, - model: deserialize(input.embedding.model), embedding: decodedEmbedding, updationTime: input.embedding.updatedAt, ),