From 703277fc1db8aec5b72673acbac8d5ce197161fd Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 17 Jan 2025 15:18:38 +0530 Subject: [PATCH] [mob][photos] Run once --- .../ui/viewer/people/save_or_edit_person.dart | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/mobile/lib/ui/viewer/people/save_or_edit_person.dart b/mobile/lib/ui/viewer/people/save_or_edit_person.dart index 8af7a280b0..c1289b9ec5 100644 --- a/mobile/lib/ui/viewer/people/save_or_edit_person.dart +++ b/mobile/lib/ui/viewer/people/save_or_edit_person.dart @@ -61,6 +61,7 @@ class _SaveOrEditPersonState extends State { late final Logger _logger = Logger("_SavePersonState"); Timer? _debounce; List<(PersonEntity, EnteFile)> _cachedPersons = []; + final Map _personToMaxSimilarity = {}; PersonEntity? person; @override @@ -342,7 +343,7 @@ class _SaveOrEditPersonState extends State { if (searchResults.isEmpty) { return const SizedBox.shrink(); } - final finalResults = await _sortByCosine(searchResults); + final finalResults = _sortByCosine(searchResults); return Column( crossAxisAlignment: CrossAxisAlignment.start, @@ -408,28 +409,27 @@ class _SaveOrEditPersonState extends State { if (_cachedPersons.isEmpty) { _cachedPersons = await _getPersonsWithRecentFile(); } + if (widget.clusterID != null) { + if (_personToMaxSimilarity.isEmpty) { + await _temp(); + } + } yield _cachedPersons; } - Future> _sortByCosine( - List<(PersonEntity, EnteFile)> searchResults, - ) async { - if (widget.clusterID == null) return searchResults; - + Future _temp() async { // Get current cluster embedding - final currentClusterSummary = - await MLDataDB.instance.getClusterToClusterSummary([widget.clusterID!]); + final allClusterSummary = await MLDataDB.instance.getAllClusterSummary(); final currentClusterEmbeddingData = - currentClusterSummary[widget.clusterID!]?.$1; - if (currentClusterEmbeddingData == null) return searchResults; + allClusterSummary[widget.clusterID!]?.$1; + if (currentClusterEmbeddingData == null) return; final ml.Vector currentClusterEmbedding = ml.Vector.fromList( EVector.fromBuffer(currentClusterEmbeddingData).values, dtype: ml.DType.float32, ); // Get all cluster embeddings - final allClusterSummary = await MLDataDB.instance.getAllClusterSummary(); - final persons = searchResults.map((e) => e.$1).toList(); + final persons = _cachedPersons.map((e) => e.$1).toList(); final clusterToPerson = {}; for (final person in persons) { if (person.data.assigned != null) { @@ -449,20 +449,29 @@ class _SaveOrEditPersonState extends State { ), ), ); + // Calculate cosine similarity between current cluster and all clusters - final Map personToMaxSimilarity = {}; for (final entry in allClusterEmbeddings.entries) { final personId = clusterToPerson[entry.key]!; final similarity = currentClusterEmbedding.dot(entry.value); - personToMaxSimilarity[personId] = max( - personToMaxSimilarity[personId] ?? double.negativeInfinity, + _personToMaxSimilarity[personId] = max( + _personToMaxSimilarity[personId] ?? double.negativeInfinity, similarity, ); } + + } + + List<(PersonEntity, EnteFile)> _sortByCosine( + List<(PersonEntity, EnteFile)> searchResults, + ) { + if (widget.clusterID == null) return searchResults; + if (_personToMaxSimilarity.isEmpty) return searchResults; + // Sort search results based on cosine similarity searchResults.sort((a, b) { - final similarityA = personToMaxSimilarity[a.$1.remoteID] ?? 0; - final similarityB = personToMaxSimilarity[b.$1.remoteID] ?? 0; + final similarityA = _personToMaxSimilarity[a.$1.remoteID] ?? 0; + final similarityB = _personToMaxSimilarity[b.$1.remoteID] ?? 0; return similarityB.compareTo(similarityA); });