Handle missing cluster summary on sorting

commit be3e33f5c5
parent 8b53dac00c
Author: Neeraj Gupta
Date:   2024-05-31 18:07:04 +05:30

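Most hunks in this diff are mechanical line re-wrapping. The behavioral change is in the @@ -996,10 +1041,25 @@ hunk of sortSuggestions: person clusters that have no entry in the cluster-summary map are now dropped with a warning before the person's average embedding is computed, and the method returns early when no summarized cluster remains. Previously a missing summary tripped the null-check operator on personClusterToSummary[clusterID]!. The embedding count is likewise accumulated with a plain loop over the surviving clusters instead of a reduce over possibly missing entries.

A minimal standalone sketch of the guard, assuming a hypothetical ClusterSummary record in place of the (embedding blob, face count) pairs read from FaceMLDataDB:

import 'dart:typed_data';

// Hypothetical stand-in for the summaries that
// getClusterToClusterSummary returns.
typedef ClusterSummary = (Uint8List embeddingBlob, int faceCount);

/// Keeps only the clusters that actually have a summary, so later
/// lookups like personClusterToSummary[clusterID]! cannot fail.
List<int> clustersWithSummary(
  List<int> personClusters,
  Map<int, ClusterSummary> personClusterToSummary,
) {
  final kept = <int>[];
  for (final clusterID in personClusters) {
    if (personClusterToSummary.containsKey(clusterID)) {
      kept.add(clusterID);
    } else {
      print('missing summary for $clusterID');
    }
  }
  return kept;
}

In the diff itself the same check runs inline and mutates personClusters in place (iterating over a toSet() copy so removal during iteration is safe) rather than building a new list.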

@@ -67,7 +67,8 @@ class ClusterFeedbackService {
     try {
       // Get the suggestions for the person using centroids and median
       final startTime = DateTime.now();
-      final List<(int, double, bool)> foundSuggestions = await _getSuggestions(person);
+      final List<(int, double, bool)> foundSuggestions =
+          await _getSuggestions(person);
       final findSuggestionsTime = DateTime.now();
       _logger.info(
         'getSuggestionForPerson `_getSuggestions`: Found ${foundSuggestions.length} suggestions in ${findSuggestionsTime.difference(startTime).inMilliseconds} ms',
@@ -141,12 +142,15 @@ class ClusterFeedbackService {
         final fileID = getFileIdFromFaceId(faceID);
         return files.any((file) => file.uploadedFileID == fileID);
       });
-      final embeddings = await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
-      final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime();
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
       // Re-cluster within the deleted faces
-      final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer(
+      final clusterResult =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
         embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
         distanceThreshold: 0.20,
@@ -159,14 +163,16 @@ class ClusterFeedbackService {
       // Update the deleted faces
       await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
-      await FaceMLDataDB.instance.clusterSummaryUpdate(clusterResult.newClusterSummaries);
+      await FaceMLDataDB.instance
+          .clusterSummaryUpdate(clusterResult.newClusterSummaries);
       // Make sure the deleted faces don't get suggested in the future
       final notClusterIdToPersonId = <int, String>{};
       for (final clusterId in newFaceIdToClusterID.values.toSet()) {
         notClusterIdToPersonId[clusterId] = p.remoteID;
       }
-      await FaceMLDataDB.instance.bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
+      await FaceMLDataDB.instance
+          .bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
       Bus.instance.fire(PeopleChangedEvent());
       return;
@@ -189,12 +195,15 @@ class ClusterFeedbackService {
         final fileID = getFileIdFromFaceId(faceID);
         return files.any((file) => file.uploadedFileID == fileID);
       });
-      final embeddings = await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
-      final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime();
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
       // Re-cluster within the deleted faces
-      final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer(
+      final clusterResult =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
         embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
         distanceThreshold: 0.20,
@@ -206,7 +215,8 @@ class ClusterFeedbackService {
       // Update the deleted faces
       await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
-      await FaceMLDataDB.instance.clusterSummaryUpdate(clusterResult.newClusterSummaries);
+      await FaceMLDataDB.instance
+          .clusterSummaryUpdate(clusterResult.newClusterSummaries);
       Bus.instance.fire(
         PeopleChangedEvent(
@@ -309,12 +319,14 @@ class ClusterFeedbackService {
     final allClusterToFaceCount = await faceMlDb.clusterIdToFaceCount();
     final clustersToInspect = <int>[];
     for (final clusterID in allClusterToFaceCount.keys) {
-      if (allClusterToFaceCount[clusterID]! > 20 && allClusterToFaceCount[clusterID]! < 500) {
+      if (allClusterToFaceCount[clusterID]! > 20 &&
+          allClusterToFaceCount[clusterID]! < 500) {
         clustersToInspect.add(clusterID);
       }
     }
-    final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime();
+    final fileIDToCreationTime =
+        await FilesDB.instance.getFileIDToCreationTime();
     final susClusters = <(int, int)>[];
@@ -325,7 +337,8 @@ class ClusterFeedbackService {
       final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs);
-      final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer(
+      final clusterResult =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
         embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
         distanceThreshold: 0.22,
@@ -338,7 +351,8 @@ class ClusterFeedbackService {
         continue;
       }
-      final newClusterIdToCount = clusterResult.newClusterIdToFaceIds.map((key, value) {
+      final newClusterIdToCount =
+          clusterResult.newClusterIdToFaceIds.map((key, value) {
        return MapEntry(key, value.length);
      });
      final amountOfNewClusters = newClusterIdToCount.length;
@@ -359,12 +373,15 @@ class ClusterFeedbackService {
       final int secondBiggestClusterID = clusterIDs.reduce((a, b) {
         return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? a : b;
       });
-      final int secondBiggestSize = newClusterIdToCount[secondBiggestClusterID]!;
+      final int secondBiggestSize =
+          newClusterIdToCount[secondBiggestClusterID]!;
       final secondBiggestRatio = secondBiggestSize / originalClusterSize;
       if (biggestRatio < 0.5 || secondBiggestRatio > 0.2) {
-        final faceIdsOfCluster = await faceMlDb.getFaceIDsForCluster(clusterID);
-        final uniqueFileIDs = faceIdsOfCluster.map(getFileIdFromFaceId).toSet();
+        final faceIdsOfCluster =
+            await faceMlDb.getFaceIDsForCluster(clusterID);
+        final uniqueFileIDs =
+            faceIdsOfCluster.map(getFileIdFromFaceId).toSet();
         susClusters.add((clusterID, uniqueFileIDs.length));
         _logger.info(
           '[CheckMixedClusters] Detected that cluster $clusterID with size ${uniqueFileIDs.length} might be mixed',
@@ -409,9 +426,11 @@ class ClusterFeedbackService {
       return ClusteringResult.empty();
     }
-    final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime();
+    final fileIDToCreationTime =
+        await FilesDB.instance.getFileIDToCreationTime();
-    final clusterResult = await FaceClusteringService.instance.predictWithinClusterComputer(
+    final clusterResult =
+        await FaceClusteringService.instance.predictWithinClusterComputer(
       embeddings,
       fileIDToCreationTime: fileIDToCreationTime,
       distanceThreshold: 0.22,
@@ -422,7 +441,8 @@ class ClusterFeedbackService {
       return ClusteringResult.empty();
     }
-    final clusterIdToCount = clusterResult.newClusterIdToFaceIds.map((key, value) {
+    final clusterIdToCount =
+        clusterResult.newClusterIdToFaceIds.map((key, value) {
      return MapEntry(key, value.length);
    });
    final amountOfNewClusters = clusterIdToCount.length;
@@ -433,7 +453,8 @@ class ClusterFeedbackService {
     if (kDebugMode) {
       final Set allClusteredFaceIDsSet = {};
-      for (final List<String> value in clusterResult.newClusterIdToFaceIds.values) {
+      for (final List<String> value
+          in clusterResult.newClusterIdToFaceIds.values) {
         allClusteredFaceIDsSet.addAll(value);
       }
       assert((originalFaceIDsSet.difference(allClusteredFaceIDsSet)).isEmpty);
@@ -447,14 +468,16 @@ class ClusterFeedbackService {
     try {
       // Delete old clusters
       await FaceMLDataDB.instance.dropClustersAndPersonTable();
-      final List<PersonEntity> persons = await PersonService.instance.getPersons();
+      final List<PersonEntity> persons =
+          await PersonService.instance.getPersons();
       for (final PersonEntity p in persons) {
         await PersonService.instance.deletePerson(p.remoteID);
       }
       // Create new fake clusters based on blur value. One for values between 0 and 10, one for 10-20, etc till 200
       final int startClusterID = DateTime.now().microsecondsSinceEpoch;
-      final faceIDsToBlurValues = await FaceMLDataDB.instance.getFaceIDsToBlurValues(200);
+      final faceIDsToBlurValues =
+          await FaceMLDataDB.instance.getFaceIDsToBlurValues(200);
       final faceIdToCluster = <String, int>{};
       for (final entry in faceIDsToBlurValues.entries) {
         final faceID = entry.key;
@@ -527,40 +550,46 @@ class ClusterFeedbackService {
     _logger.info(
       'L2 norm of current mean: $currentL2Norm',
     );
-    final trueDistance = biggestMean.distanceTo(currentMean, distance: Distance.cosine);
+    final trueDistance =
+        biggestMean.distanceTo(currentMean, distance: Distance.cosine);
     _logger.info('True distance between the two means: $trueDistance');
     // Median distance
     const sampleSize = 100;
-    final Iterable<Uint8List> biggestEmbeddings =
-        await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(biggestClusterID);
-    final List<Uint8List> biggestSampledEmbeddingsProto = _randomSampleWithoutReplacement(
+    final Iterable<Uint8List> biggestEmbeddings = await FaceMLDataDB
+        .instance
+        .getFaceEmbeddingsForCluster(biggestClusterID);
+    final List<Uint8List> biggestSampledEmbeddingsProto =
+        _randomSampleWithoutReplacement(
      biggestEmbeddings,
      sampleSize,
    );
-    final List<Vector> biggestSampledEmbeddings = biggestSampledEmbeddingsProto
-        .map(
-          (embedding) => Vector.fromList(
-            EVector.fromBuffer(embedding).values,
-            dtype: DType.float32,
-          ),
-        )
-        .toList(growable: false);
+    final List<Vector> biggestSampledEmbeddings =
+        biggestSampledEmbeddingsProto
+            .map(
+              (embedding) => Vector.fromList(
+                EVector.fromBuffer(embedding).values,
+                dtype: DType.float32,
+              ),
+            )
+            .toList(growable: false);
     final Iterable<Uint8List> currentEmbeddings =
         await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
-    final List<Uint8List> currentSampledEmbeddingsProto = _randomSampleWithoutReplacement(
+    final List<Uint8List> currentSampledEmbeddingsProto =
+        _randomSampleWithoutReplacement(
      currentEmbeddings,
      sampleSize,
    );
-    final List<Vector> currentSampledEmbeddings = currentSampledEmbeddingsProto
-        .map(
-          (embedding) => Vector.fromList(
-            EVector.fromBuffer(embedding).values,
-            dtype: DType.float32,
-          ),
-        )
-        .toList(growable: false);
+    final List<Vector> currentSampledEmbeddings =
+        currentSampledEmbeddingsProto
+            .map(
+              (embedding) => Vector.fromList(
+                EVector.fromBuffer(embedding).values,
+                dtype: DType.float32,
+              ),
+            )
+            .toList(growable: false);
     // Calculate distances and find the median
     final List<double> distances = [];
@@ -579,7 +608,8 @@ class ClusterFeedbackService {
     distances.sort();
     trueDistances.sort();
     final double medianDistance = distances[distances.length ~/ 2];
-    final double trueMedianDistance = trueDistances[trueDistances.length ~/ 2];
+    final double trueMedianDistance =
+        trueDistances[trueDistances.length ~/ 2];
     _logger.info(
       "Median distance between biggest cluster and current cluster: $medianDistance (using sample of $sampleSize)",
     );
@@ -594,7 +624,8 @@ class ClusterFeedbackService {
     final List<double> blurValues = await FaceMLDataDB.instance
         .getBlurValuesForCluster(clusterID)
         .then((value) => value.toList());
-    final blurValuesIntegers = blurValues.map((value) => value.round()).toList();
+    final blurValuesIntegers =
+        blurValues.map((value) => value.round()).toList();
     blurValuesIntegers.sort();
     _logger.info(
       "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers",
@@ -622,12 +653,14 @@ class ClusterFeedbackService {
     final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
-    final personFaceIDs = await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID);
+    final personFaceIDs =
+        await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID);
     final personFileIDs = personFaceIDs.map(getFileIdFromFaceId).toSet();
     w?.log(
       '${p.data.name} has ${personClusters.length} existing clusters, getting all database data done',
     );
-    final allClusterIdToFaceIDs = await FaceMLDataDB.instance.getAllClusterIdToFaceIDs();
+    final allClusterIdToFaceIDs =
+        await FaceMLDataDB.instance.getAllClusterIdToFaceIDs();
     w?.log('getAllClusterIdToFaceIDs done');
     // First only do a simple check on the big clusters, if the person does not have small clusters yet
@@ -638,7 +671,8 @@ class ClusterFeedbackService {
     late Map<int, Vector> clusterAvgBigClusters;
     final List<(int, double)> suggestionsMean = [];
     for (final minimumSize in checkSizes.toSet()) {
-      if (smallestPersonClusterSize >= min(minimumSize, kMinimumClusterSizeSearchResult)) {
+      if (smallestPersonClusterSize >=
+          min(minimumSize, kMinimumClusterSizeSearchResult)) {
         clusterAvgBigClusters = await _getUpdateClusterAvg(
           allClusterIdsToCountMap,
           ignoredClusters,
@@ -647,7 +681,8 @@ class ClusterFeedbackService {
         w?.log(
           'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize',
         );
-        final List<(int, double)> suggestionsMeanBigClusters = await calcSuggestionsMeanInComputer(
+        final List<(int, double)> suggestionsMeanBigClusters =
+            await calcSuggestionsMeanInComputer(
          clusterAvgBigClusters,
          personClusters,
          ignoredClusters,
@@ -662,7 +697,8 @@ class ClusterFeedbackService {
               .map((faceID) => getFileIdFromFaceId(faceID))
               .toSet();
           final overlap = personFileIDs.intersection(suggestionSet);
-          if (overlap.isNotEmpty && ((overlap.length / suggestionSet.length) > 0.5)) {
+          if (overlap.isNotEmpty &&
+              ((overlap.length / suggestionSet.length) > 0.5)) {
             await FaceMLDataDB.instance.captureNotPersonFeedback(
               personID: p.remoteID,
               clusterID: suggestion.$1,
@@ -672,7 +708,9 @@ class ClusterFeedbackService {
           suggestionsMean.add(suggestion);
         }
         if (suggestionsMean.isNotEmpty) {
-          return suggestionsMean.map((e) => (e.$1, e.$2, true)).toList(growable: false);
+          return suggestionsMean
+              .map((e) => (e.$1, e.$2, true))
+              .toList(growable: false);
         }
       }
     }
@@ -680,14 +718,16 @@ class ClusterFeedbackService {
     // Find the other cluster candidates based on the median
     final clusterAvg = clusterAvgBigClusters;
-    final List<(int, double)> moreSuggestionsMean = await calcSuggestionsMeanInComputer(
+    final List<(int, double)> moreSuggestionsMean =
+        await calcSuggestionsMeanInComputer(
      clusterAvg,
      personClusters,
      ignoredClusters,
      maxMeanDistance,
    );
     if (moreSuggestionsMean.isEmpty) {
-      _logger.info("No suggestions found using mean, even with higher threshold");
+      _logger
+          .info("No suggestions found using mean, even with higher threshold");
       return [];
     }
@@ -709,7 +749,8 @@ class ClusterFeedbackService {
           await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
       personEmbeddingsProto.addAll(embeddings);
     }
-    final List<Uint8List> sampledEmbeddingsProto = _randomSampleWithoutReplacement(
+    final List<Uint8List> sampledEmbeddingsProto =
+        _randomSampleWithoutReplacement(
      personEmbeddingsProto,
      sampleSize,
    );
@@ -853,7 +894,8 @@ class ClusterFeedbackService {
     // get clusterIDs sorted by count in descending order
     final sortedClusterIDs = allClusterIds.toList();
     sortedClusterIDs.sort(
-      (a, b) => allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
+      (a, b) =>
+          allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
     );
     int indexedInCurrentRun = 0;
     w?.reset();
@@ -868,7 +910,8 @@ class ClusterFeedbackService {
       currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0;
       clusterIdsToRead.add(clusterID);
     } else {
-      if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) < maxEmbeddingToRead) {
+      if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) <
+          maxEmbeddingToRead) {
         clusterIdsToRead.add(clusterID);
         currentPendingRead += allClusterIdsToCountMap[clusterID]!;
       } else {
@@ -877,8 +920,9 @@ class ClusterFeedbackService {
       }
     }
-    final Map<int, Iterable<Uint8List>> clusterEmbeddings =
-        await FaceMLDataDB.instance.getFaceEmbeddingsForClusters(clusterIdsToRead);
+    final Map<int, Iterable<Uint8List>> clusterEmbeddings = await FaceMLDataDB
+        .instance
+        .getFaceEmbeddingsForClusters(clusterIdsToRead);
     w?.logAndReset(
       'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
@@ -895,7 +939,8 @@ class ClusterFeedbackService {
       final avg = vectors.reduce((a, b) => a + b) / vectors.length;
       final avgNormalized = avg / avg.norm();
       final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer();
-      updatesForClusterSummary[clusterID] = (avgEmbeddingBuffer, embeddings.length);
+      updatesForClusterSummary[clusterID] =
+          (avgEmbeddingBuffer, embeddings.length);
       // store the intermediate updates
       indexedInCurrentRun++;
       if (updatesForClusterSummary.length > 100) {
@@ -996,10 +1041,25 @@ class ClusterFeedbackService {
         await faceMlDb.getClusterToClusterSummary(personClusters);
     final clusterSummaryCallTime = DateTime.now();
+    // remove personClusters that don't have any summary
+    for (final clusterID in personClusters.toSet()) {
+      if (!personClusterToSummary.containsKey(clusterID)) {
+        _logger.warning('missing summary for $clusterID');
+        personClusters.remove(clusterID);
+      }
+    }
+    if (personClusters.isEmpty) {
+      _logger.warning('No person clusters with summary found');
+      return;
+    }
     // Calculate the avg embedding of the person
     final w = (kDebugMode ? EnteWatch('sortSuggestions') : null)?..start();
-    final personEmbeddingsCount =
-        personClusters.map((e) => personClusterToSummary[e]!.$2).reduce((a, b) => a + b);
+    int personEmbeddingsCount = 0;
+    for (final clusterID in personClusters) {
+      personEmbeddingsCount += personClusterToSummary[clusterID]!.$2;
+    }
     Vector personAvg = Vector.filled(192, 0);
     for (final personClusterID in personClusters) {
       final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
@@ -1007,7 +1067,8 @@ class ClusterFeedbackService {
         EVector.fromBuffer(personClusterBlob).values,
         dtype: DType.float32,
       );
-      final clusterWeight = personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
+      final clusterWeight =
+          personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
       personAvg += personClusterAvg * clusterWeight;
     }
     w?.log('calculated person avg');
@@ -1038,7 +1099,8 @@ class ClusterFeedbackService {
       );
       final fileIdToDistanceMap = {};
       for (final entry in faceIdToVectorMap.entries) {
-        fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = 1 - personAvg.dot(entry.value);
+        fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
+            1 - personAvg.dot(entry.value);
       }
       w?.log('calculated distances for cluster $clusterID');
       suggestion.filesInCluster.sort((b, a) {
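
Note on the last three hunks: the person's average embedding is a face-count-weighted mean of the per-cluster summary embeddings, and each file is then scored by 1 - personAvg.dot(embedding), which equals cosine distance when both vectors are unit-normalized. A minimal sketch of that math with ml_linalg, using made-up 3-dimensional vectors in place of the real 192-dimensional embeddings (the explicit re-normalization of the weighted mean is an assumption added here so the unit-vector identity holds exactly):

import 'package:ml_linalg/linalg.dart';

void main() {
  // Hypothetical (average embedding, face count) summaries per cluster.
  final summaries = <int, (Vector, int)>{
    1: (Vector.fromList([1.0, 0.0, 0.0], dtype: DType.float32), 30),
    2: (Vector.fromList([0.0, 1.0, 0.0], dtype: DType.float32), 10),
  };

  // Face-count-weighted mean, as in the loop over personClusters above.
  final totalFaces = summaries.values.fold<int>(0, (sum, s) => sum + s.$2);
  Vector personAvg = Vector.filled(3, 0);
  for (final (clusterAvg, faceCount) in summaries.values) {
    personAvg += clusterAvg * (faceCount / totalFaces);
  }
  personAvg = personAvg / personAvg.norm();

  // Cosine distance via 1 - dot on unit vectors; smaller values sort
  // first, mirroring what fileIdToDistanceMap stores above.
  final candidate = Vector.fromList([0.6, 0.8, 0.0], dtype: DType.float32);
  print(1 - personAvg.dot(candidate));
}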