From cfb4ded9913d87f638753aef3ed08a2ce932b8c5 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 29 May 2024 11:13:21 +0530 Subject: [PATCH] [mob][photos] Fix breakupCluster not returning cluster summaries --- .../face_clustering_service.dart | 83 +++++++++---------- .../face_ml/feedback/cluster_feedback.dart | 2 +- 2 files changed, 41 insertions(+), 44 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 310deb964a..35a25748ac 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -45,6 +45,8 @@ class ClusteringResult { bool get isEmpty => newFaceIdToCluster.isEmpty; + bool get hasAllResults => newClusterSummaries != null && newClusterIdToFaceIds != null; + ClusteringResult({ required this.newFaceIdToCluster, this.newClusterSummaries, @@ -127,8 +129,7 @@ class FaceClusteringService { break; } } catch (e, stackTrace) { - sendPort - .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + sendPort.send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); } }); } @@ -256,6 +257,7 @@ class FaceClusteringService { Future predictLinearComputer( Map input, { Map? fileIDToCreationTime, + required Map oldClusterSummaries, double distanceThreshold = kRecommendedDistanceThreshold, }) async { if (input.isEmpty) { @@ -291,6 +293,7 @@ class FaceClusteringService { param: { "input": clusteringInput, "fileIDToCreationTime": fileIDToCreationTime, + "oldClusterSummaries": oldClusterSummaries, "distanceThreshold": distanceThreshold, "conservativeDistanceThreshold": distanceThreshold - 0.08, "useDynamicThreshold": false, @@ -314,6 +317,7 @@ class FaceClusteringService { Future predictCompleteComputer( Map input, { Map? fileIDToCreationTime, + required Map oldClusterSummaries, double distanceThreshold = kRecommendedDistanceThreshold, double mergeThreshold = 0.30, }) async { @@ -336,6 +340,7 @@ class FaceClusteringService { param: { "input": input, "fileIDToCreationTime": fileIDToCreationTime, + "oldClusterSummaries": oldClusterSummaries, "distanceThreshold": distanceThreshold, "mergeThreshold": mergeThreshold, }, @@ -355,6 +360,7 @@ class FaceClusteringService { Future predictWithinClusterComputer( Map input, { Map? fileIDToCreationTime, + Map oldClusterSummaries = const {}, double distanceThreshold = kRecommendedDistanceThreshold, }) async { _logger.info( @@ -369,6 +375,7 @@ class FaceClusteringService { final result = await predictCompleteComputer( input, fileIDToCreationTime: fileIDToCreationTime, + oldClusterSummaries: oldClusterSummaries, distanceThreshold: distanceThreshold - 0.08, mergeThreshold: mergeThreshold, ); @@ -381,6 +388,7 @@ class FaceClusteringService { final clusterResult = await predictLinearComputer( input, fileIDToCreationTime: fileIDToCreationTime, + oldClusterSummaries: oldClusterSummaries, distanceThreshold: distanceThreshold, ); return clusterResult; @@ -396,12 +404,10 @@ class FaceClusteringService { final input = args['input'] as Set; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; final distanceThreshold = args['distanceThreshold'] as double; - final conservativeDistanceThreshold = - args['conservativeDistanceThreshold'] as double; + final conservativeDistanceThreshold = args['conservativeDistanceThreshold'] as double; final useDynamicThreshold = args['useDynamicThreshold'] as bool; final offset = args['offset'] as int?; - final oldClusterSummaries = - args['oldClusterSummaries'] as Map?; + final oldClusterSummaries = args['oldClusterSummaries'] as Map?; log( "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", @@ -425,8 +431,7 @@ class FaceClusteringService { dtype: DType.float32, ), clusterId: face.clusterId, - fileCreationTime: - fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)], + fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)], ), ); } @@ -493,9 +498,8 @@ class FaceClusteringService { double closestDistance = double.infinity; late double thresholdValue; if (useDynamicThreshold) { - thresholdValue = sortedFaceInfos[i].badFace! - ? conservativeDistanceThreshold - : distanceThreshold; + thresholdValue = + sortedFaceInfos[i].badFace! ? conservativeDistanceThreshold : distanceThreshold; if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++; } else { thresholdValue = distanceThreshold; @@ -505,11 +509,10 @@ class FaceClusteringService { } // WARNING: The loop below is now O(n^2) so be very careful with anything you put in there! for (int j = i - 1; j >= 0; j--) { - final double distance = 1 - - sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!); + final double distance = + 1 - sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!); if (distance < closestDistance) { - if (sortedFaceInfos[j].badFace! && - distance > conservativeDistanceThreshold) { + if (sortedFaceInfos[j].badFace! && distance > conservativeDistanceThreshold) { continue; } closestDistance = distance; @@ -535,8 +538,7 @@ class FaceClusteringService { // Finally, assign the new clusterId to the faces final Map newFaceIdToCluster = {}; - final newClusteredFaceInfos = - sortedFaceInfos.sublist(alreadyClusteredCount); + final newClusteredFaceInfos = sortedFaceInfos.sublist(alreadyClusteredCount); for (final faceInfo in newClusteredFaceInfos) { newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; } @@ -597,9 +599,8 @@ class FaceClusteringService { final Map newClusterSummaries = {}; for (final clusterId in newClusterIdToFaceInfos.keys) { - final List newEmbeddings = newClusterIdToFaceInfos[clusterId]! - .map((faceInfo) => faceInfo.vEmbedding!) - .toList(); + final List newEmbeddings = + newClusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList(); final newCount = newEmbeddings.length; if (oldSummary.containsKey(clusterId)) { final oldMean = Vector.fromList( @@ -609,8 +610,7 @@ class FaceClusteringService { final oldCount = oldSummary[clusterId]!.$2; final oldEmbeddings = oldMean * oldCount; newEmbeddings.add(oldEmbeddings); - final newMeanVector = - newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); + final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); newClusterSummaries[clusterId] = ( EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), @@ -619,10 +619,8 @@ class FaceClusteringService { } else { final newMeanVector = newEmbeddings.reduce((a, b) => a + b); final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); - newClusterSummaries[clusterId] = ( - EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), - newCount - ); + newClusterSummaries[clusterId] = + (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount); } } log( @@ -696,6 +694,7 @@ class FaceClusteringService { final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; final distanceThreshold = args['distanceThreshold'] as double; final mergeThreshold = args['mergeThreshold'] as double; + final oldClusterSummaries = args['oldClusterSummaries'] as Map?; log( "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", @@ -711,8 +710,7 @@ class FaceClusteringService { EVector.fromBuffer(entry.value).values, dtype: DType.float32, ), - fileCreationTime: - fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], ), ); } @@ -746,8 +744,7 @@ class FaceClusteringService { double closestDistance = double.infinity; for (int j = 0; j < totalFaces; j++) { if (i == j) continue; - final double distance = - 1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!); + final double distance = 1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!); if (distance < closestDistance) { closestDistance = distance; closestIdx = j; @@ -777,21 +774,17 @@ class FaceClusteringService { } final Map clusterIdToMeanEmbeddingAndWeight = {}; for (final clusterId in clusterIdToFaceInfos.keys) { - final List embeddings = clusterIdToFaceInfos[clusterId]! - .map((faceInfo) => faceInfo.vEmbedding!) - .toList(); + final List embeddings = + clusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList(); final count = clusterIdToFaceInfos[clusterId]!.length; final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; - final Vector meanEmbeddingNormalized = - meanEmbedding / meanEmbedding.norm(); - clusterIdToMeanEmbeddingAndWeight[clusterId] = - (meanEmbeddingNormalized, count); + final Vector meanEmbeddingNormalized = meanEmbedding / meanEmbedding.norm(); + clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbeddingNormalized, count); } // Now merge the clusters that are close to each other, based on mean embedding final List<(int, int)> mergedClustersList = []; - final List clusterIds = - clusterIdToMeanEmbeddingAndWeight.keys.toList(); + final List clusterIds = clusterIdToMeanEmbeddingAndWeight.keys.toList(); log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); while (true) { if (clusterIds.length < 2) break; @@ -858,10 +851,14 @@ class FaceClusteringService { } } - final newClusterSummaries = FaceClusteringService.updateClusterSummaries( - oldSummary: {}, - newFaceInfos: faceInfos, - ); + // Now calculate the mean of the embeddings for each cluster and update the cluster summaries + Map? newClusterSummaries; + if (oldClusterSummaries != null) { + newClusterSummaries = FaceClusteringService.updateClusterSummaries( + oldSummary: oldClusterSummaries, + newFaceInfos: faceInfos, + ); + } stopwatchClustering.stop(); log( diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index b354719f9a..5a41a60df7 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -434,7 +434,7 @@ class ClusterFeedbackService { ); if (clusterResult == null || - clusterResult.newClusterIdToFaceIds == null || + !clusterResult.hasAllResults || clusterResult.isEmpty) { _logger.warning('No clusters found or something went wrong'); return ClusteringResult(newFaceIdToCluster: {});