From 7509a367c5477e47de52c1aa4e3192737576555d Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Mon, 30 Jun 2025 11:43:22 +0200 Subject: [PATCH] Fix suggestion not using previous negative feedback --- mobile/lib/db/ml/base.dart | 1 + mobile/lib/db/ml/db.dart | 15 +++++++++++ .../face_ml/feedback/cluster_feedback.dart | 27 +++++++++++++++++-- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/mobile/lib/db/ml/base.dart b/mobile/lib/db/ml/base.dart index 0f8994192e..5bac0233d6 100644 --- a/mobile/lib/db/ml/base.dart +++ b/mobile/lib/db/ml/base.dart @@ -13,6 +13,7 @@ abstract class IMLDataDB { Future getFaceIndexedFileCount({int minimumMlVersion}); Future> clusterIdToFaceCount(); Future> getPersonIgnoredClusters(String personID); + Future>> getPersonToRejectedSuggestions(); Future> getPersonClusterIDs(String personID); Future> getPersonsClusterIDs(List personID); Future clearTable(); diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index bba2a96d47..851e349d00 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -212,6 +212,21 @@ class MLDataDB with SqlDbBase implements IMLDataDB { return ignoredClusterIDs.union(rejectClusterIDs); } + @override + Future>> getPersonToRejectedSuggestions() async { + final db = await instance.asyncDB; + final List> rejectMaps = await db.getAll( + 'SELECT $personIdColumn, $clusterIDColumn FROM $notPersonFeedback', + ); + final Map> result = {}; + for (final map in rejectMaps) { + final personID = map[personIdColumn] as String; + final clusterID = map[clusterIDColumn] as String; + result.putIfAbsent(personID, () => {}).add(clusterID); + } + return result; + } + @override Future> getPersonClusterIDs(String personID) async { final db = await instance.asyncDB; diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 09cf6335a0..2cefa84ac6 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -357,11 +357,23 @@ class ClusterFeedbackService { allOtherPersonClustersToIgnore, minClusterSize: kMinimumClusterSizeSearchResult, ); + + final Map> personClusterToIgnoredClusters = {}; + final personToRejectedSuggestions = + await mlDataDB.getPersonToRejectedSuggestions(); + for (final personID in personToRejectedSuggestions.keys) { + final personCluster = personIdToBiggestCluster[personID]; + if (personCluster == null) continue; + final ignoredClusters = personToRejectedSuggestions[personID] ?? {}; + personClusterToIgnoredClusters[personCluster] = ignoredClusters; + } + final List<(String, double, String)> foundSuggestions = await calcSuggestionsMeanInComputer( clusterAvg, allPersonClusters, allOtherPersonClustersToIgnore, + personClusterToIgnoredClusters: personClusterToIgnoredClusters, 0.55, ); @@ -1014,8 +1026,9 @@ class ClusterFeedbackService { Map clusterAvg, Set personClusters, Set ignoredClusters, - double maxClusterDistance, - ) async { + double maxClusterDistance, { + Map>? personClusterToIgnoredClusters, + }) async { return await _computer.compute( _calcSuggestionsMean, param: { @@ -1023,6 +1036,7 @@ class ClusterFeedbackService { 'personClusters': personClusters, 'ignoredClusters': ignoredClusters, 'maxClusterDistance': maxClusterDistance, + 'personClusterToIgnoredClusters': personClusterToIgnoredClusters, }, ); } @@ -1341,6 +1355,9 @@ List<(String, double, String)> _calcSuggestionsMean(Map args) { final Set personClusters = args['personClusters']; final Set ignoredClusters = args['ignoredClusters']; final double maxClusterDistance = args['maxClusterDistance']; + final Map>? personClusterToIgnoredClusters = + args['personClusterToIgnoredClusters']; + final bool extraIgnoreCheck = personClusterToIgnoredClusters != null; final Map> suggestions = {}; const suggestionMax = 2000; @@ -1366,6 +1383,12 @@ List<(String, double, String)> _calcSuggestionsMean(Map args) { dev.log('[WARNING] no avg for personcluster $personCluster'); continue; } + if (extraIgnoreCheck && + personClusterToIgnoredClusters[personCluster] != null && + personClusterToIgnoredClusters[personCluster]! + .contains(otherClusterID)) { + continue; + } final Vector avg = clusterAvg[personCluster]!; final distance = 1 - avg.dot(otherAvg); comparisons++;