diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 236f71c43d..4caf79b316 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -455,7 +455,7 @@ class ClusterFeedbackService { double maxMedianDistance = 0.62, double goodMedianDistance = 0.55, double maxMeanDistance = 0.65, - double goodMeanDistance = 0.54, + double goodMeanDistance = 0.50, }) async { final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); // Get all the cluster data @@ -472,10 +472,10 @@ class ClusterFeedbackService { .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0) .reduce((value, element) => min(value, element)); final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1]; + late final Map clusterAvgBigClusters; for (final minimumSize in checkSizes.toSet()) { if (smallestPersonClusterSize >= minimumSize) { - final Map clusterAvgBigClusters = - await _getUpdateClusterAvg( + clusterAvgBigClusters = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, minClusterSize: minimumSize, @@ -502,29 +502,8 @@ class ClusterFeedbackService { } w?.reset(); - // Get and update the cluster summary to get the avg (centroid) and count - final Map clusterAvg = await _getUpdateClusterAvg( - allClusterIdsToCountMap, - ignoredClusters, - ); - w?.log( - 'computed avg for ${clusterAvg.length} clusters,', - ); - - // Find the other cluster candidates based on the mean - final List<(int, double)> suggestionsMean = _calcSuggestionsMean( - clusterAvg, - personClusters, - ignoredClusters, - goodMeanDistance, - ); - if (suggestionsMean.isNotEmpty) { - return suggestionsMean - .map((e) => (e.$1, e.$2, true)) - .toList(growable: false); - } - // Find the other cluster candidates based on the median + final clusterAvg = clusterAvgBigClusters; final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean( clusterAvg, personClusters, @@ -595,7 +574,7 @@ class ClusterFeedbackService { final List distances = []; for (final otherEmbedding in sampledOtherEmbeddings) { for (final embedding in sampledEmbeddings) { - distances.add(cosineDistanceSIMD(embedding,otherEmbedding)); + distances.add(cosineDistanceSIMD(embedding, otherEmbedding)); } } distances.sort(); @@ -800,7 +779,7 @@ class ClusterFeedbackService { continue; } final Vector avg = clusterAvg[personCluster]!; - final distance = cosineDistanceSIMD(avg,otherAvg); + final distance = cosineDistanceSIMD(avg, otherAvg); if (distance < maxClusterDistance) { if (minDistance == null || distance < minDistance) { minDistance = distance; @@ -951,7 +930,7 @@ class ClusterFeedbackService { final fileIdToDistanceMap = {}; for (final entry in faceIdToVectorMap.entries) { fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = - cosineDistanceSIMD(personAvg,entry.value); + cosineDistanceSIMD(personAvg, entry.value); } w?.log('calculated distances for cluster $clusterID'); suggestion.filesInCluster.sort((b, a) {