[mob][photos] Normalize weighted embeddings for cluster summary

This commit is contained in:
laurenspriem
2024-05-06 17:03:06 +05:30
parent 71b572917e
commit b4736fb1d6
2 changed files with 8 additions and 5 deletions

View File

@@ -662,14 +662,16 @@ class FaceClusteringService {
newEmbeddings.add(oldEmbeddings);
final newMeanVector =
newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
newClusterSummaries[clusterId] = (
EVector(values: newMeanVector.toList()).writeToBuffer(),
EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
oldCount + newCount
);
} else {
final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
newClusterSummaries[clusterId] =
(EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
(EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
}
}
log(

View File

@@ -868,7 +868,8 @@ class ClusterFeedbackService {
),
);
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
final avgNormalized = avg / avg.norm();
final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer();
updatesForClusterSummary[clusterID] =
(avgEmbeddingBuffer, embeddings.length);
// store the intermediate updates
@@ -882,7 +883,7 @@ class ClusterFeedbackService {
);
}
}
clusterAvg[clusterID] = avg;
clusterAvg[clusterID] = avgNormalized;
}
if (updatesForClusterSummary.isNotEmpty) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);