From b5cff212bbe897ad131168ba381e549b8d0dc8a2 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 21 Mar 2024 16:59:55 +0530 Subject: [PATCH] Refactor of clustering --- .../linear_clustering_service.dart | 171 ++++++++---------- 1 file changed, 72 insertions(+), 99 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart index d82db2b43f..c427cbc930 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart @@ -7,6 +7,7 @@ import "dart:typed_data"; import "package:logging/logging.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; +import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:synchronized/synchronized.dart"; class FaceInfo { @@ -15,10 +16,12 @@ class FaceInfo { int? clusterId; String? closestFaceId; int? closestDist; + int? fileCreationTime; FaceInfo({ required this.faceID, required this.embedding, this.clusterId, + this.fileCreationTime, }); } @@ -31,7 +34,6 @@ class FaceLinearClustering { final Duration _inactivityDuration = const Duration(seconds: 30); int _activeTasks = 0; - final _initLock = Lock(); late Isolate _isolate; @@ -151,8 +153,8 @@ class FaceLinearClustering { _resetInactivityTimer(); } else { _logger.info( - 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', - ); + 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', + ); dispose(); } }); @@ -220,6 +222,8 @@ class FaceLinearClustering { log( "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", ); + + // Organize everything into a list of FaceInfo objects final List faceInfos = []; for (final entry in x.entries) { faceInfos.add( @@ -249,59 +253,61 @@ class FaceLinearClustering { } // Sort the faceInfos such that the ones with null clusterId are at the end - faceInfos.sort((a, b) { - if (a.clusterId == null && b.clusterId == null) { - return 0; - } else if (a.clusterId == null) { - return 1; - } else if (b.clusterId == null) { - return -1; - } else { - return 0; - } - }); - // Count the amount of null values at the end - int nullCount = 0; - for (final faceInfo in faceInfos.reversed) { + final List facesWithClusterID = []; + final List facesWithoutClusterID = []; + for (final FaceInfo faceInfo in faceInfos) { if (faceInfo.clusterId == null) { - nullCount++; + facesWithoutClusterID.add(faceInfo); } else { - break; + facesWithClusterID.add(faceInfo); } } + final sortedFaceInfos = []; + sortedFaceInfos.addAll(facesWithClusterID); + sortedFaceInfos.addAll(facesWithoutClusterID); + log( - "[ClusterIsolate] ${DateTime.now()} Clustering $nullCount new faces without clusterId, and ${faceInfos.length - nullCount} faces with clusterId", + "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId", ); - for (final clusteredFaceInfo - in faceInfos.sublist(0, faceInfos.length - nullCount)) { - assert(clusteredFaceInfo.clusterId != null); + + // Make sure the first face has a clusterId + final int totalFaces = sortedFaceInfos.length; + int clusterID = 1; + if (sortedFaceInfos.isNotEmpty) { + if (sortedFaceInfos.first.clusterId == null) { + sortedFaceInfos.first.clusterId = clusterID; + } else { + clusterID = sortedFaceInfos.first.clusterId!; + } + } else { + return {}; } - final int totalFaces = faceInfos.length; - int clusterID = 1; - if (faceInfos.isNotEmpty) { - faceInfos.first.clusterId = clusterID; - } + // Start actual clustering log( "[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces", ); + final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 1; i < totalFaces; i++) { // Incremental clustering, so we can skip faces that already have a clusterId - if (faceInfos[i].clusterId != null) { - clusterID = max(clusterID, faceInfos[i].clusterId!); + if (sortedFaceInfos[i].clusterId != null) { + clusterID = max(clusterID, sortedFaceInfos[i].clusterId!); + if (i % 250 == 0) { + log("[ClusterIsolate] ${DateTime.now()} First $i faces already had a clusterID"); + } continue; } - final currentEmbedding = faceInfos[i].embedding; + final currentEmbedding = sortedFaceInfos[i].embedding; int closestIdx = -1; double closestDistance = double.infinity; if (i % 250 == 0) { log("[ClusterIsolate] ${DateTime.now()} Processing $i faces"); } - for (int j = 0; j < i; j++) { + for (int j = i - 1; j >= 0; j--) { final double distance = cosineDistForNormVectors( currentEmbedding, - faceInfos[j].embedding, + sortedFaceInfos[j].embedding, ); if (distance < closestDistance) { closestDistance = distance; @@ -310,42 +316,43 @@ class FaceLinearClustering { } if (closestDistance < recommendedDistanceThreshold) { - if (faceInfos[closestIdx].clusterId == null) { + if (sortedFaceInfos[closestIdx].clusterId == null) { // Ideally this should never happen, but just in case log it log( - " [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID", + " [ClusterIsolate] [WARNING] ${DateTime.now()} Found new cluster $clusterID", ); clusterID++; - faceInfos[closestIdx].clusterId = clusterID; + sortedFaceInfos[closestIdx].clusterId = clusterID; + newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID; } - faceInfos[i].clusterId = faceInfos[closestIdx].clusterId; + sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; + newFaceIdToCluster[sortedFaceInfos[i].faceID] = + sortedFaceInfos[closestIdx].clusterId!; } else { clusterID++; - faceInfos[i].clusterId = clusterID; + sortedFaceInfos[i].clusterId = clusterID; + newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID; } } - final Map result = {}; - for (final faceInfo in faceInfos) { - result[faceInfo.faceID] = faceInfo.clusterId!; - } + stopwatchClustering.stop(); log( - ' [ClusterIsolate] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings (${faceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', + ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', ); - // return result; - // NOTe: The main clustering logic is done, the following is just filtering and logging - final input = x; - final faceIdToCluster = result; - stopwatchClustering.reset(); - stopwatchClustering.start(); + // analyze the results + FaceLinearClustering._analyzeClusterResults(sortedFaceInfos); - final Set newFaceIds = {}; - input.forEach((key, value) { - if (value.$1 == null) { - newFaceIds.add(key); - } - }); + return newFaceIdToCluster; + } + + static void _analyzeClusterResults(List sortedFaceInfos) { + final stopwatch = Stopwatch()..start(); + + final Map faceIdToCluster = {}; + for (final faceInfo in sortedFaceInfos) { + faceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } // Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs final Map clusterIdToSize = {}; @@ -356,12 +363,6 @@ class FaceLinearClustering { clusterIdToSize[value] = 1; } }); - final Map faceIdToClusterFiltered = {}; - for (final entry in faceIdToCluster.entries) { - if (clusterIdToSize[entry.value]! > 0 && newFaceIds.contains(entry.key)) { - faceIdToClusterFiltered[entry.key] = entry.value; - } - } // print top 10 cluster ids and their sizes based on the internal cluster id final clusterIds = faceIdToCluster.values.toSet(); @@ -369,7 +370,7 @@ class FaceLinearClustering { return faceIdToCluster.values.where((id) => id == clusterId).length; }).toList(); clusterSizes.sort(); - // find clusters whose size is graeter than 1 + // find clusters whose size is greater than 1 int oneClusterCount = 0; int moreThan5Count = 0; int moreThan10Count = 0; @@ -377,57 +378,29 @@ class FaceLinearClustering { int moreThan50Count = 0; int moreThan100Count = 0; - // for (int i = 0; i < clusterSizes.length; i++) { - // if (clusterSizes[i] > 100) { - // moreThan100Count++; - // } else if (clusterSizes[i] > 50) { - // moreThan50Count++; - // } else if (clusterSizes[i] > 20) { - // moreThan20Count++; - // } else if (clusterSizes[i] > 10) { - // moreThan10Count++; - // } else if (clusterSizes[i] > 5) { - // moreThan5Count++; - // } else if (clusterSizes[i] == 1) { - // oneClusterCount++; - // } - // } for (int i = 0; i < clusterSizes.length; i++) { if (clusterSizes[i] > 100) { moreThan100Count++; - } - if (clusterSizes[i] > 50) { + } else if (clusterSizes[i] > 50) { moreThan50Count++; - } - if (clusterSizes[i] > 20) { + } else if (clusterSizes[i] > 20) { moreThan20Count++; - } - if (clusterSizes[i] > 10) { + } else if (clusterSizes[i] > 10) { moreThan10Count++; - } - if (clusterSizes[i] > 5) { + } else if (clusterSizes[i] > 5) { moreThan5Count++; - } - if (clusterSizes[i] == 1) { + } else if (clusterSizes[i] == 1) { oneClusterCount++; } } + // print the metrics log( - '[ClusterIsolate] Total clusters ${clusterIds.length}, ' - 'oneClusterCount $oneClusterCount, ' - 'moreThan5Count $moreThan5Count, ' - 'moreThan10Count $moreThan10Count, ' - 'moreThan20Count $moreThan20Count, ' - 'moreThan50Count $moreThan50Count, ' - 'moreThan100Count $moreThan100Count', + "[ClusterIsolate] Total clusters ${clusterIds.length}: \n oneClusterCount $oneClusterCount \n moreThan5Count $moreThan5Count \n moreThan10Count $moreThan10Count \n moreThan20Count $moreThan20Count \n moreThan50Count $moreThan50Count \n moreThan100Count $moreThan100Count", ); - stopwatchClustering.stop(); + stopwatch.stop(); log( - "[ClusterIsolate] Clustering additional steps took ${stopwatchClustering.elapsedMilliseconds} ms", + "[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms", ); - - // log('Top clusters count ${clusterSizes.reversed.take(10).toList()}'); - return faceIdToClusterFiltered; } }