From be3e33f5c5099624ed70b5136a62d8099e6cfdd8 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Fri, 31 May 2024 18:07:04 +0530 Subject: [PATCH] Handle missing cluster summary on sorting --- .../face_ml/feedback/cluster_feedback.dart | 190 ++++++++++++------ 1 file changed, 126 insertions(+), 64 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index af6f07a1ab..c328104b75 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -67,7 +67,8 @@ class ClusterFeedbackService { try { // Get the suggestions for the person using centroids and median final startTime = DateTime.now(); - final List<(int, double, bool)> foundSuggestions = await _getSuggestions(person); + final List<(int, double, bool)> foundSuggestions = + await _getSuggestions(person); final findSuggestionsTime = DateTime.now(); _logger.info( 'getSuggestionForPerson `_getSuggestions`: Found ${foundSuggestions.length} suggestions in ${findSuggestionsTime.difference(startTime).inMilliseconds} ms', @@ -141,12 +142,15 @@ class ClusterFeedbackService { final fileID = getFileIdFromFaceId(faceID); return files.any((file) => file.uploadedFileID == fileID); }); - final embeddings = await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); - final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); // Re-cluster within the deleted faces - final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer( + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.20, @@ -159,14 +163,16 @@ class ClusterFeedbackService { // Update the deleted faces await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); - await FaceMLDataDB.instance.clusterSummaryUpdate(clusterResult.newClusterSummaries); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusterResult.newClusterSummaries); // Make sure the deleted faces don't get suggested in the future final notClusterIdToPersonId = {}; for (final clusterId in newFaceIdToClusterID.values.toSet()) { notClusterIdToPersonId[clusterId] = p.remoteID; } - await FaceMLDataDB.instance.bulkCaptureNotPersonFeedback(notClusterIdToPersonId); + await FaceMLDataDB.instance + .bulkCaptureNotPersonFeedback(notClusterIdToPersonId); Bus.instance.fire(PeopleChangedEvent()); return; @@ -189,12 +195,15 @@ class ClusterFeedbackService { final fileID = getFileIdFromFaceId(faceID); return files.any((file) => file.uploadedFileID == fileID); }); - final embeddings = await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); - final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); // Re-cluster within the deleted faces - final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer( + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.20, @@ -206,7 +215,8 @@ class ClusterFeedbackService { // Update the deleted faces await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); - await FaceMLDataDB.instance.clusterSummaryUpdate(clusterResult.newClusterSummaries); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusterResult.newClusterSummaries); Bus.instance.fire( PeopleChangedEvent( @@ -309,12 +319,14 @@ class ClusterFeedbackService { final allClusterToFaceCount = await faceMlDb.clusterIdToFaceCount(); final clustersToInspect = []; for (final clusterID in allClusterToFaceCount.keys) { - if (allClusterToFaceCount[clusterID]! > 20 && allClusterToFaceCount[clusterID]! < 500) { + if (allClusterToFaceCount[clusterID]! > 20 && + allClusterToFaceCount[clusterID]! < 500) { clustersToInspect.add(clusterID); } } - final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); final susClusters = <(int, int)>[]; @@ -325,7 +337,8 @@ class ClusterFeedbackService { final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs); - final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer( + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.22, @@ -338,7 +351,8 @@ class ClusterFeedbackService { continue; } - final newClusterIdToCount = clusterResult.newClusterIdToFaceIds.map((key, value) { + final newClusterIdToCount = + clusterResult.newClusterIdToFaceIds.map((key, value) { return MapEntry(key, value.length); }); final amountOfNewClusters = newClusterIdToCount.length; @@ -359,12 +373,15 @@ class ClusterFeedbackService { final int secondBiggestClusterID = clusterIDs.reduce((a, b) { return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? a : b; }); - final int secondBiggestSize = newClusterIdToCount[secondBiggestClusterID]!; + final int secondBiggestSize = + newClusterIdToCount[secondBiggestClusterID]!; final secondBiggestRatio = secondBiggestSize / originalClusterSize; if (biggestRatio < 0.5 || secondBiggestRatio > 0.2) { - final faceIdsOfCluster = await faceMlDb.getFaceIDsForCluster(clusterID); - final uniqueFileIDs = faceIdsOfCluster.map(getFileIdFromFaceId).toSet(); + final faceIdsOfCluster = + await faceMlDb.getFaceIDsForCluster(clusterID); + final uniqueFileIDs = + faceIdsOfCluster.map(getFileIdFromFaceId).toSet(); susClusters.add((clusterID, uniqueFileIDs.length)); _logger.info( '[CheckMixedClusters] Detected that cluster $clusterID with size ${uniqueFileIDs.length} might be mixed', @@ -409,9 +426,11 @@ class ClusterFeedbackService { return ClusteringResult.empty(); } - final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); - final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer( + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.22, @@ -422,7 +441,8 @@ class ClusterFeedbackService { return ClusteringResult.empty(); } - final clusterIdToCount = clusterResult.newClusterIdToFaceIds.map((key, value) { + final clusterIdToCount = + clusterResult.newClusterIdToFaceIds.map((key, value) { return MapEntry(key, value.length); }); final amountOfNewClusters = clusterIdToCount.length; @@ -433,7 +453,8 @@ class ClusterFeedbackService { if (kDebugMode) { final Set allClusteredFaceIDsSet = {}; - for (final List value in clusterResult.newClusterIdToFaceIds.values) { + for (final List value + in clusterResult.newClusterIdToFaceIds.values) { allClusteredFaceIDsSet.addAll(value); } assert((originalFaceIDsSet.difference(allClusteredFaceIDsSet)).isEmpty); @@ -447,14 +468,16 @@ class ClusterFeedbackService { try { // Delete old clusters await FaceMLDataDB.instance.dropClustersAndPersonTable(); - final List persons = await PersonService.instance.getPersons(); + final List persons = + await PersonService.instance.getPersons(); for (final PersonEntity p in persons) { await PersonService.instance.deletePerson(p.remoteID); } // Create new fake clusters based on blur value. One for values between 0 and 10, one for 10-20, etc till 200 final int startClusterID = DateTime.now().microsecondsSinceEpoch; - final faceIDsToBlurValues = await FaceMLDataDB.instance.getFaceIDsToBlurValues(200); + final faceIDsToBlurValues = + await FaceMLDataDB.instance.getFaceIDsToBlurValues(200); final faceIdToCluster = {}; for (final entry in faceIDsToBlurValues.entries) { final faceID = entry.key; @@ -527,40 +550,46 @@ class ClusterFeedbackService { _logger.info( 'L2 norm of current mean: $currentL2Norm', ); - final trueDistance = biggestMean.distanceTo(currentMean, distance: Distance.cosine); + final trueDistance = + biggestMean.distanceTo(currentMean, distance: Distance.cosine); _logger.info('True distance between the two means: $trueDistance'); // Median distance const sampleSize = 100; - final Iterable biggestEmbeddings = - await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(biggestClusterID); - final List biggestSampledEmbeddingsProto = _randomSampleWithoutReplacement( + final Iterable biggestEmbeddings = await FaceMLDataDB + .instance + .getFaceEmbeddingsForCluster(biggestClusterID); + final List biggestSampledEmbeddingsProto = + _randomSampleWithoutReplacement( biggestEmbeddings, sampleSize, ); - final List biggestSampledEmbeddings = biggestSampledEmbeddingsProto - .map( - (embedding) => Vector.fromList( - EVector.fromBuffer(embedding).values, - dtype: DType.float32, - ), - ) - .toList(growable: false); + final List biggestSampledEmbeddings = + biggestSampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) + .toList(growable: false); final Iterable currentEmbeddings = await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); - final List currentSampledEmbeddingsProto = _randomSampleWithoutReplacement( + final List currentSampledEmbeddingsProto = + _randomSampleWithoutReplacement( currentEmbeddings, sampleSize, ); - final List currentSampledEmbeddings = currentSampledEmbeddingsProto - .map( - (embedding) => Vector.fromList( - EVector.fromBuffer(embedding).values, - dtype: DType.float32, - ), - ) - .toList(growable: false); + final List currentSampledEmbeddings = + currentSampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) + .toList(growable: false); // Calculate distances and find the median final List distances = []; @@ -579,7 +608,8 @@ class ClusterFeedbackService { distances.sort(); trueDistances.sort(); final double medianDistance = distances[distances.length ~/ 2]; - final double trueMedianDistance = trueDistances[trueDistances.length ~/ 2]; + final double trueMedianDistance = + trueDistances[trueDistances.length ~/ 2]; _logger.info( "Median distance between biggest cluster and current cluster: $medianDistance (using sample of $sampleSize)", ); @@ -594,7 +624,8 @@ class ClusterFeedbackService { final List blurValues = await FaceMLDataDB.instance .getBlurValuesForCluster(clusterID) .then((value) => value.toList()); - final blurValuesIntegers = blurValues.map((value) => value.round()).toList(); + final blurValuesIntegers = + blurValues.map((value) => value.round()).toList(); blurValuesIntegers.sort(); _logger.info( "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers", @@ -622,12 +653,14 @@ class ClusterFeedbackService { final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount(); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); - final personFaceIDs = await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID); + final personFaceIDs = + await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID); final personFileIDs = personFaceIDs.map(getFileIdFromFaceId).toSet(); w?.log( '${p.data.name} has ${personClusters.length} existing clusters, getting all database data done', ); - final allClusterIdToFaceIDs = await FaceMLDataDB.instance.getAllClusterIdToFaceIDs(); + final allClusterIdToFaceIDs = + await FaceMLDataDB.instance.getAllClusterIdToFaceIDs(); w?.log('getAllClusterIdToFaceIDs done'); // First only do a simple check on the big clusters, if the person does not have small clusters yet @@ -638,7 +671,8 @@ class ClusterFeedbackService { late Map clusterAvgBigClusters; final List<(int, double)> suggestionsMean = []; for (final minimumSize in checkSizes.toSet()) { - if (smallestPersonClusterSize >= min(minimumSize, kMinimumClusterSizeSearchResult)) { + if (smallestPersonClusterSize >= + min(minimumSize, kMinimumClusterSizeSearchResult)) { clusterAvgBigClusters = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, @@ -647,7 +681,8 @@ class ClusterFeedbackService { w?.log( 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', ); - final List<(int, double)> suggestionsMeanBigClusters = await calcSuggestionsMeanInComputer( + final List<(int, double)> suggestionsMeanBigClusters = + await calcSuggestionsMeanInComputer( clusterAvgBigClusters, personClusters, ignoredClusters, @@ -662,7 +697,8 @@ class ClusterFeedbackService { .map((faceID) => getFileIdFromFaceId(faceID)) .toSet(); final overlap = personFileIDs.intersection(suggestionSet); - if (overlap.isNotEmpty && ((overlap.length / suggestionSet.length) > 0.5)) { + if (overlap.isNotEmpty && + ((overlap.length / suggestionSet.length) > 0.5)) { await FaceMLDataDB.instance.captureNotPersonFeedback( personID: p.remoteID, clusterID: suggestion.$1, @@ -672,7 +708,9 @@ class ClusterFeedbackService { suggestionsMean.add(suggestion); } if (suggestionsMean.isNotEmpty) { - return suggestionsMean.map((e) => (e.$1, e.$2, true)).toList(growable: false); + return suggestionsMean + .map((e) => (e.$1, e.$2, true)) + .toList(growable: false); } } } @@ -680,14 +718,16 @@ class ClusterFeedbackService { // Find the other cluster candidates based on the median final clusterAvg = clusterAvgBigClusters; - final List<(int, double)> moreSuggestionsMean = await calcSuggestionsMeanInComputer( + final List<(int, double)> moreSuggestionsMean = + await calcSuggestionsMeanInComputer( clusterAvg, personClusters, ignoredClusters, maxMeanDistance, ); if (moreSuggestionsMean.isEmpty) { - _logger.info("No suggestions found using mean, even with higher threshold"); + _logger + .info("No suggestions found using mean, even with higher threshold"); return []; } @@ -709,7 +749,8 @@ class ClusterFeedbackService { await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); personEmbeddingsProto.addAll(embeddings); } - final List sampledEmbeddingsProto = _randomSampleWithoutReplacement( + final List sampledEmbeddingsProto = + _randomSampleWithoutReplacement( personEmbeddingsProto, sampleSize, ); @@ -853,7 +894,8 @@ class ClusterFeedbackService { // get clusterIDs sorted by count in descending order final sortedClusterIDs = allClusterIds.toList(); sortedClusterIDs.sort( - (a, b) => allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), + (a, b) => + allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), ); int indexedInCurrentRun = 0; w?.reset(); @@ -868,7 +910,8 @@ class ClusterFeedbackService { currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0; clusterIdsToRead.add(clusterID); } else { - if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) < maxEmbeddingToRead) { + if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) < + maxEmbeddingToRead) { clusterIdsToRead.add(clusterID); currentPendingRead += allClusterIdsToCountMap[clusterID]!; } else { @@ -877,8 +920,9 @@ class ClusterFeedbackService { } } - final Map> clusterEmbeddings = - await FaceMLDataDB.instance.getFaceEmbeddingsForClusters(clusterIdsToRead); + final Map> clusterEmbeddings = await FaceMLDataDB + .instance + .getFaceEmbeddingsForClusters(clusterIdsToRead); w?.logAndReset( 'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters', @@ -895,7 +939,8 @@ class ClusterFeedbackService { final avg = vectors.reduce((a, b) => a + b) / vectors.length; final avgNormalized = avg / avg.norm(); final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer(); - updatesForClusterSummary[clusterID] = (avgEmbeddingBuffer, embeddings.length); + updatesForClusterSummary[clusterID] = + (avgEmbeddingBuffer, embeddings.length); // store the intermediate updates indexedInCurrentRun++; if (updatesForClusterSummary.length > 100) { @@ -996,10 +1041,25 @@ class ClusterFeedbackService { await faceMlDb.getClusterToClusterSummary(personClusters); final clusterSummaryCallTime = DateTime.now(); + // remove personClusters that don't have any summary + for (final clusterID in personClusters.toSet()) { + if (!personClusterToSummary.containsKey(clusterID)) { + _logger.warning('missing summary for $clusterID'); + personClusters.remove(clusterID); + } + } + if (personClusters.isEmpty) { + _logger.warning('No person clusters with summary found'); + return; + } + // Calculate the avg embedding of the person final w = (kDebugMode ? EnteWatch('sortSuggestions') : null)?..start(); - final personEmbeddingsCount = - personClusters.map((e) => personClusterToSummary[e]!.$2).reduce((a, b) => a + b); + int personEmbeddingsCount = 0; + for (final clusterID in personClusters) { + personEmbeddingsCount += personClusterToSummary[clusterID]!.$2; + } + Vector personAvg = Vector.filled(192, 0); for (final personClusterID in personClusters) { final personClusterBlob = personClusterToSummary[personClusterID]!.$1; @@ -1007,7 +1067,8 @@ class ClusterFeedbackService { EVector.fromBuffer(personClusterBlob).values, dtype: DType.float32, ); - final clusterWeight = personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; + final clusterWeight = + personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; personAvg += personClusterAvg * clusterWeight; } w?.log('calculated person avg'); @@ -1038,7 +1099,8 @@ class ClusterFeedbackService { ); final fileIdToDistanceMap = {}; for (final entry in faceIdToVectorMap.entries) { - fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = 1 - personAvg.dot(entry.value); + fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = + 1 - personAvg.dot(entry.value); } w?.log('calculated distances for cluster $clusterID'); suggestion.filesInCluster.sort((b, a) {