From b5113dd4205043fbe8637c84e62632014e1615bc Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 5 Aug 2025 11:52:03 +0530 Subject: [PATCH] Add option for exact search of similar files --- .../apps/photos/lib/db/ml/clip_vector_db.dart | 16 ++++++++--- .../similar_images_service.dart | 27 ++++++++++++++++--- .../debug/ml_debug_section_widget.dart | 2 ++ .../lib/ui/tools/similar_images_page.dart | 26 +++++++++++++++++- 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/mobile/apps/photos/lib/db/ml/clip_vector_db.dart b/mobile/apps/photos/lib/db/ml/clip_vector_db.dart index e78c100067..9fcffecf30 100644 --- a/mobile/apps/photos/lib/db/ml/clip_vector_db.dart +++ b/mobile/apps/photos/lib/db/ml/clip_vector_db.dart @@ -176,11 +176,15 @@ class ClipVectorDB { Future<(Uint64List, Float32List)> searchClosestVectors( List query, int count, + bool exact, ) async { final db = await _vectorDB; try { - final result = - await db.searchVectors(query: query, count: BigInt.from(count)); + final result = await db.searchVectors( + query: query, + count: BigInt.from(count), + exact: exact, + ); return result; } catch (e, s) { _logger.severe("Error searching closest vectors", e, s); @@ -190,10 +194,12 @@ class ClipVectorDB { Future<(BigInt, double)> searchClosestVector( List query, + bool exact, ) async { final db = await _vectorDB; try { - final result = await db.searchVectors(query: query, count: BigInt.one); + final result = + await db.searchVectors(query: query, count: BigInt.one, exact: exact); return (result.$1[0], result.$2[0]); } catch (e, s) { _logger.severe("Error searching closest vector", e, s); @@ -204,12 +210,14 @@ class ClipVectorDB { Future<(List, List)> bulkSearchVectors( List queries, BigInt count, + bool exact, ) async { final db = await _vectorDB; try { final result = await db.bulkSearchVectors( queries: queries, count: count, + exact: exact, ); return result; } catch (e, s) { @@ -230,7 +238,7 @@ class ClipVectorDB { final minimumSimilarity = minimumSimilarityMap[query]!; final textEmbedding = entry.value; final (potentialFileIDs, distances) = - await searchClosestVectors(textEmbedding, 1000); + await searchClosestVectors(textEmbedding, 1000, true); final queryResults = []; for (var i = 0; i < potentialFileIDs.length; i++) { final similarity = 1 - distances[i]; diff --git a/mobile/apps/photos/lib/services/machine_learning/similar_images_service.dart b/mobile/apps/photos/lib/services/machine_learning/similar_images_service.dart index ee56868162..880ab3884d 100644 --- a/mobile/apps/photos/lib/services/machine_learning/similar_images_service.dart +++ b/mobile/apps/photos/lib/services/machine_learning/similar_images_service.dart @@ -1,11 +1,13 @@ import "dart:math" show max; import "dart:typed_data" show Float32List; +import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart" show Uint64List; import 'package:logging/logging.dart'; import "package:photos/db/ml/clip_vector_db.dart"; import "package:photos/db/ml/db.dart"; +import "package:photos/extensions/stop_watch.dart"; import 'package:photos/models/file/file.dart'; import "package:photos/models/file/file_type.dart"; import "package:photos/models/ml/vector.dart"; @@ -21,10 +23,18 @@ class SimilarImagesService { /// Returns a list of SimilarFiles, where each SimilarFiles object contains /// a list of files that are perceptually similar - Future> getSimilarFiles(double distanceThreshold) async { + Future> getSimilarFiles( + double distanceThreshold, + bool exact, + ) async { try { + final now = DateTime.now(); final List result = - await _getSimilarFiles(distanceThreshold); + await _getSimilarFiles(distanceThreshold, exact); + final duration = DateTime.now().difference(now); + _logger.info( + "Found ${result.length} similar files in ${duration.inSeconds} seconds for threshold $distanceThreshold and exact $exact", + ); return result; } catch (e, s) { _logger.severe("failed to get similar files", e, s); @@ -32,10 +42,15 @@ class SimilarImagesService { } } - Future> _getSimilarFiles(double distanceThreshold) async { + Future> _getSimilarFiles( + double distanceThreshold, + bool exact, + ) async { + final w = (kDebugMode ? EnteWatch('getSimilarFiles') : null)?..start(); final mlDataDB = MLDataDB.instance; _logger.info("Checking migration and filling clip vector DB"); await mlDataDB.checkMigrateFillClipVectorDB(); + w?.log("checkMigrateFillClipVectorDB"); // Get the embeddings ready for vector search final List allImageEmbeddings = @@ -48,13 +63,16 @@ class SimilarImagesService { final keys = Uint64List.fromList( allImageEmbeddings.map((e) => BigInt.from(e.fileID)).toList(), ); + w?.log("getAllClipVectors"); // Run bulk vector search final (vectorKeys, distances) = await ClipVectorDB.instance.bulkSearchVectors( clipFloat32, BigInt.from(100), + exact, ); + w?.log("bulkSearchVectors"); // Get all files, and create a map of fileID to file final allFiles = Set.from( @@ -66,6 +84,7 @@ class SimilarImagesService { allFileIdsToFile[file.uploadedFileID!] = file; } } + w?.log("getAllFilesForSearch"); // Run through the vector search results and create SimilarFiles objects final alreadyUsedFileIDs = {}; @@ -111,9 +130,11 @@ class SimilarImagesService { allSimilarFiles.add(similarFiles); } } + w?.log("going through files"); // Sort the similar files by total size in descending order allSimilarFiles.sort((a, b) => b.totalSize.compareTo(a.totalSize)); + w?.log("sort similar files"); return allSimilarFiles; } diff --git a/mobile/apps/photos/lib/ui/settings/debug/ml_debug_section_widget.dart b/mobile/apps/photos/lib/ui/settings/debug/ml_debug_section_widget.dart index 93fb519185..b905228ee8 100644 --- a/mobile/apps/photos/lib/ui/settings/debug/ml_debug_section_widget.dart +++ b/mobile/apps/photos/lib/ui/settings/debug/ml_debug_section_widget.dart @@ -420,6 +420,7 @@ class _MLDebugSectionWidgetState extends State { final (vectorKeys, distances) = await vectorDB.bulkSearchVectors( queries: queries, count: count, + exact: false, ); w?.log( @@ -504,6 +505,7 @@ class _MLDebugSectionWidgetState extends State { final (vectorKeys, distances) = await vectorDB.bulkSearchVectors( queries: clipFloat32, count: count, + exact: false, ); w?.log( diff --git a/mobile/apps/photos/lib/ui/tools/similar_images_page.dart b/mobile/apps/photos/lib/ui/tools/similar_images_page.dart index cf289e3c17..1c6447a0a3 100644 --- a/mobile/apps/photos/lib/ui/tools/similar_images_page.dart +++ b/mobile/apps/photos/lib/ui/tools/similar_images_page.dart @@ -12,6 +12,7 @@ import 'package:photos/theme/ente_theme.dart'; import "package:photos/ui/common/loading_widget.dart"; import 'package:photos/ui/components/buttons/button_widget.dart'; import "package:photos/ui/components/models/button_type.dart"; +import "package:photos/ui/components/toggle_switch_widget.dart"; import "package:photos/ui/viewer/file/detail_page.dart"; import "package:photos/ui/viewer/file/thumbnail_widget.dart"; import "package:photos/utils/dialog_util.dart"; @@ -49,6 +50,7 @@ class _SimilarImagesPageState extends State { double _distanceThreshold = 0.04; // Default value List _similarFilesList = []; SortKey _sortKey = SortKey.size; + bool _exactSearch = false; @override void initState() { @@ -159,6 +161,25 @@ class _SimilarImagesPageState extends State { textAlign: TextAlign.center, ), const SizedBox(height: 48), + Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text( + "Exact search", // TODO: lau: extract string + style: getEnteTextTheme(context).bodyBold, + ), + ToggleSwitchWidget( + value: () => _exactSearch, + onChanged: () async { + if (_isDisposed) return; + setState(() { + _exactSearch = !_exactSearch; + }); + }, + ), + ], + ), + const SizedBox(height: 32), ButtonWidget( labelText: "Find similar images", // TODO: lau: extract string buttonType: ButtonType.primary, @@ -269,8 +290,11 @@ class _SimilarImagesPageState extends State { }); try { + // You can use _toggleValue here for advanced mode features + _logger.info("exact mode: $_exactSearch"); + final similarFiles = await SimilarImagesService.instance - .getSimilarFiles(_distanceThreshold); + .getSimilarFiles(_distanceThreshold, _exactSearch); _logger.info( "Found ${similarFiles.length} groups of similar images", );