diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index f33e1ef3ef..cc598da25c 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -20,6 +20,7 @@ import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/face/face_with_embedding.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/models/ml/vector.dart"; +import "package:photos/service_locator.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import "package:photos/services/machine_learning/ml_result.dart"; @@ -84,7 +85,7 @@ class MLDataDB with SqlDbBase implements IMLDataDB { ); stopwatch.stop(); _logger.info("Starting CLIP vector DB migration check unawaited"); - unawaited(checkMigrateFillClipVectorDB()); + if (flagService.enableVectorDb) unawaited(checkMigrateFillClipVectorDB()); return asyncDBConnection; } @@ -1335,21 +1336,25 @@ class MLDataDB with SqlDbBase implements IMLDataDB { 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)', _getRowFromEmbedding(embeddings.first), ); - await ClipVectorDB.instance.insertEmbedding( - fileID: embeddings.first.fileID, - embedding: embeddings.first.embedding, - ); + if (flagService.enableVectorDb) { + await ClipVectorDB.instance.insertEmbedding( + fileID: embeddings.first.fileID, + embedding: embeddings.first.embedding, + ); + } } else { final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); await db.executeBatch( 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)', inputs, ); - await ClipVectorDB.instance.bulkInsertEmbeddings( - fileIDs: embeddings.map((e) => e.fileID).toList(), - embeddings: - embeddings.map((e) => Float32List.fromList(e.embedding)).toList(), - ); + if (flagService.enableVectorDb) { + await ClipVectorDB.instance.bulkInsertEmbeddings( + fileIDs: embeddings.map((e) => e.fileID).toList(), + embeddings: + embeddings.map((e) => Float32List.fromList(e.embedding)).toList(), + ); + } } Bus.instance.fire(EmbeddingUpdatedEvent()); } @@ -1360,7 +1365,9 @@ class MLDataDB with SqlDbBase implements IMLDataDB { await db.execute( 'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})', ); - await ClipVectorDB.instance.deleteEmbeddings(fileIDs); + if (flagService.enableVectorDb) { + await ClipVectorDB.instance.deleteEmbeddings(fileIDs); + } Bus.instance.fire(EmbeddingUpdatedEvent()); } @@ -1368,7 +1375,9 @@ class MLDataDB with SqlDbBase implements IMLDataDB { Future deleteClipIndexes() async { final db = await instance.asyncDB; await db.execute('DELETE FROM $clipTable'); - await ClipVectorDB.instance.deleteAllEmbeddings(); + if (flagService.enableVectorDb) { + await ClipVectorDB.instance.deleteAllEmbeddings(); + } Bus.instance.fire(EmbeddingUpdatedEvent()); } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 5923ccc7b7..f2b18e515b 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -266,13 +266,25 @@ class SemanticSearchService { required Map minimumSimilarityMap, }) async { final startTime = DateTime.now(); - // TODO: lau: remove this when we feel confident about vector DB - // await _cacheClipVectors(); - // final Map> queryResults = await MLComputer - // .instance - // .computeBulkSimilarities(textQueryToEmbeddingMap, minimumSimilarityMap); - final queryResults = await ClipVectorDB.instance - .computeBulkSimilarities(textQueryToEmbeddingMap, minimumSimilarityMap); + if (kDebugMode) { + for (final queryText in textQueryToEmbeddingMap.keys) { + final embedding = textQueryToEmbeddingMap[queryText]!; + dev.log("CLIPTEXT Query: $queryText, embedding: $embedding"); + } + } + late final Map> queryResults; + if (flagService.enableVectorDb) { + queryResults = await ClipVectorDB.instance.computeBulkSimilarities( + textQueryToEmbeddingMap, + minimumSimilarityMap, + ); + } else { + await _cacheClipVectors(); + queryResults = await MLComputer.instance.computeBulkSimilarities( + textQueryToEmbeddingMap, + minimumSimilarityMap, + ); + } final endTime = DateTime.now(); _logger.info( "computingSimilarities took for ${textQueryToEmbeddingMap.length} queries " + diff --git a/mobile/plugins/ente_feature_flag/lib/src/service.dart b/mobile/plugins/ente_feature_flag/lib/src/service.dart index 376f6b6e48..dddd32aab4 100644 --- a/mobile/plugins/ente_feature_flag/lib/src/service.dart +++ b/mobile/plugins/ente_feature_flag/lib/src/service.dart @@ -58,6 +58,8 @@ class FlagService { bool get enableMobMultiPart => flags.enableMobMultiPart || internalUser; + bool get enableVectorDb => flags.internalUser; + String get castUrl => flags.castUrl; Future setMapEnabled(bool isEnabled) async {