diff --git a/mobile/lib/db/ml/clip_db.dart b/mobile/lib/db/ml/clip_db.dart index 477be3de83..f160924d52 100644 --- a/mobile/lib/db/ml/clip_db.dart +++ b/mobile/lib/db/ml/clip_db.dart @@ -7,16 +7,23 @@ import "package:photos/db/ml/db_fields.dart"; import "package:photos/events/embedding_updated_event.dart"; import "package:photos/models/ml/clip.dart"; import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/models/ml/vector.dart"; extension ClipDB on MLDataDB { static const databaseName = "ente.embeddings.db"; - Future> getAll() async { + Future> getAllClipEmbeddings() async { final db = await MLDataDB.instance.asyncDB; final results = await db.getAll('SELECT * FROM $clipTable'); return _convertToEmbeddings(results); } + Future> getAllClipVectors() async { + final db = await MLDataDB.instance.asyncDB; + final results = await db.getAll('SELECT * FROM $clipTable'); + return _convertToVectors(results); + } + // Get indexed FileIDs Future> clipIndexedFileWithVersion() async { final db = await MLDataDB.instance.asyncDB; @@ -83,6 +90,16 @@ extension ClipDB on MLDataDB { return embeddings; } + List _convertToVectors(List> results) { + final List embeddings = []; + for (final result in results) { + final embedding = _getVectorFromRow(result); + if (embedding.isEmpty) continue; + embeddings.add(embedding); + } + return embeddings; + } + ClipEmbedding _getEmbeddingFromRow(Map row) { final fileID = row[fileIDColumn] as int; final bytes = row[embeddingColumn] as Uint8List; @@ -91,6 +108,13 @@ extension ClipDB on MLDataDB { return ClipEmbedding(fileID: fileID, embedding: list, version: version); } + EmbeddingVector _getVectorFromRow(Map row) { + final fileID = row[fileIDColumn] as int; + final bytes = row[embeddingColumn] as Uint8List; + final list = Float32List.view(bytes.buffer); + return EmbeddingVector(fileID: fileID, embedding: list); + } + List _getRowFromEmbedding(ClipEmbedding embedding) { return [ embedding.fileID, diff --git a/mobile/lib/models/ml/vector.dart b/mobile/lib/models/ml/vector.dart new file mode 100644 index 0000000000..1b0fd21491 --- /dev/null +++ b/mobile/lib/models/ml/vector.dart @@ -0,0 +1,23 @@ +import "dart:convert"; + +import "package:ml_linalg/vector.dart"; + +class EmbeddingVector { + final int fileID; + final Vector vector; + + bool get isEmpty => vector.isEmpty; + + EmbeddingVector({ + required this.fileID, + required List embedding, + }) : vector = Vector.fromList(embedding); + + static Vector decodeEmbedding(String embedding) { + return Vector.fromList(List.from(jsonDecode(embedding) as List)); + } + + static String encodeEmbedding(Vector embedding) { + return jsonEncode(embedding.toList()); + } +} diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 7beb284df7..e145dcb83f 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -6,6 +6,7 @@ import "dart:ui" show Image; import "package:computer/computer.dart"; import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; +import "package:ml_linalg/vector.dart"; import "package:photos/core/cache/lru_map.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; @@ -15,9 +16,9 @@ import 'package:photos/events/embedding_updated_event.dart'; import "package:photos/models/file/file.dart"; import "package:photos/models/ml/clip.dart"; import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/models/ml/vector.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/collections_service.dart"; -import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart"; import "package:photos/services/machine_learning/ml_computer.dart"; import "package:photos/services/machine_learning/ml_result.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart"; @@ -37,7 +38,7 @@ class SemanticSearchService { bool _hasInitialized = false; bool _textModelIsLoaded = false; - Future>? _cachedImageEmbeddings; + Future>? _cachedImageEmbeddingVectors; Future<(String, List)>? _searchScreenRequest; String? _latestPendingQuery; @@ -53,10 +54,10 @@ class SemanticSearchService { // call getClipEmbeddings after 5 seconds Future.delayed(const Duration(seconds: 5), () async { - await getClipEmbeddings(); + await getClipVectors(); }); Bus.instance.on().listen((event) { - _cachedImageEmbeddings = null; + _cachedImageEmbeddingVectors = null; }); unawaited(_loadTextModel(delay: true)); @@ -106,10 +107,10 @@ class SemanticSearchService { _logger.info("Indexes cleared"); } - Future> getClipEmbeddings() async { - _logger.info("Pulling cached embeddings"); - _cachedImageEmbeddings ??= MLDataDB.instance.getAll(); - return _cachedImageEmbeddings!; + Future> getClipVectors() async { + _logger.info("Pulling cached clip embeddings"); + _cachedImageEmbeddingVectors ??= MLDataDB.instance.getAllClipVectors(); + return _cachedImageEmbeddingVectors!; } Future> getMatchingFiles( @@ -267,7 +268,7 @@ class SemanticSearchService { double? minimumSimilarity, }) async { final startTime = DateTime.now(); - final embeddings = await getClipEmbeddings(); + final embeddings = await getClipVectors(); final List queryResults = await _computer.compute( computeBulkSimilarities, param: { @@ -311,25 +312,33 @@ class SemanticSearchService { List computeBulkSimilarities(Map args) { final queryResults = []; - final imageEmbeddings = args["imageEmbeddings"] as List; + final imageEmbeddings = args["imageEmbeddings"] as List; final textEmbedding = args["textEmbedding"] as List; final minimumSimilarity = args["minimumSimilarity"] ?? SemanticSearchService.kMinimumSimilarityThreshold; - double bestScore = 0.0; - for (final imageEmbedding in imageEmbeddings) { - final score = computeCosineSimilarity( - imageEmbedding.embedding, - textEmbedding, - ); - if (score >= minimumSimilarity) { - queryResults.add(QueryResult(imageEmbedding.fileID, score)); + + final Vector textVector = Vector.fromList(textEmbedding); + if (!kDebugMode) { + for (final imageEmbedding in imageEmbeddings) { + final similarity = imageEmbedding.vector.dot(textVector); + if (similarity >= minimumSimilarity) { + queryResults.add(QueryResult(imageEmbedding.fileID, similarity)); + } } - if (score > bestScore) { - bestScore = score; + } else { + double bestScore = 0.0; + for (final imageEmbedding in imageEmbeddings) { + final similarity = imageEmbedding.vector.dot(textVector); + if (similarity >= minimumSimilarity) { + queryResults.add(QueryResult(imageEmbedding.fileID, similarity)); + } + if (similarity > bestScore) { + bestScore = similarity; + } + } + if (kDebugMode && queryResults.isEmpty) { + dev.log("No results found for query with best score: $bestScore"); } - } - if (kDebugMode && queryResults.isEmpty) { - dev.log("No results found for query with best score: $bestScore"); } queryResults.sort((first, second) => second.score.compareTo(first.score));