diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 625e21ab8e..5f7f509355 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -236,10 +236,14 @@ class SemanticSearchService { _queue.clear(); } - Future> _getMatchingFiles(String query) async { + Future> _getMatchingFiles( + String query, { + double? scoreThreshold, + }) async { final textEmbedding = await _getTextEmbedding(query); - final queryResults = await _getScores(textEmbedding); + final queryResults = + await _getScores(textEmbedding, scoreThreshold: scoreThreshold); final filesMap = await FilesDB.instance .getFilesFromIDs(queryResults.map((e) => e.id).toList()); @@ -355,13 +359,17 @@ class SemanticSearchService { } } - Future> _getScores(List textEmbedding) async { + Future> _getScores( + List textEmbedding, { + double? scoreThreshold, + }) async { final startTime = DateTime.now(); final List queryResults = await _computer.compute( computeBulkScore, param: { "imageEmbeddings": _cachedEmbeddings, "textEmbedding": textEmbedding, + "scoreThreshold": scoreThreshold, }, taskName: "computeBulkScore", ); @@ -402,12 +410,14 @@ List computeBulkScore(Map args) { final queryResults = []; final imageEmbeddings = args["imageEmbeddings"] as List; final textEmbedding = args["textEmbedding"] as List; + final scoreThreshold = + args["scoreThreshold"] ?? SemanticSearchService.kScoreThreshold; for (final imageEmbedding in imageEmbeddings) { final score = computeScore( imageEmbedding.embedding, textEmbedding, ); - if (score >= SemanticSearchService.kScoreThreshold) { + if (score >= scoreThreshold) { queryResults.add(QueryResult(imageEmbedding.fileID, score)); } }