[mob][photos] Normalize weighted embeddings for cluster summary
This commit is contained in:
@@ -662,14 +662,16 @@ class FaceClusteringService {
|
||||
newEmbeddings.add(oldEmbeddings);
|
||||
final newMeanVector =
|
||||
newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
|
||||
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
|
||||
newClusterSummaries[clusterId] = (
|
||||
EVector(values: newMeanVector.toList()).writeToBuffer(),
|
||||
EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
|
||||
oldCount + newCount
|
||||
);
|
||||
} else {
|
||||
final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
|
||||
final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
|
||||
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
|
||||
newClusterSummaries[clusterId] =
|
||||
(EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
|
||||
(EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
|
||||
}
|
||||
}
|
||||
log(
|
||||
|
||||
@@ -868,7 +868,8 @@ class ClusterFeedbackService {
|
||||
),
|
||||
);
|
||||
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
||||
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
||||
final avgNormalized = avg / avg.norm();
|
||||
final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer();
|
||||
updatesForClusterSummary[clusterID] =
|
||||
(avgEmbeddingBuffer, embeddings.length);
|
||||
// store the intermediate updates
|
||||
@@ -882,7 +883,7 @@ class ClusterFeedbackService {
|
||||
);
|
||||
}
|
||||
}
|
||||
clusterAvg[clusterID] = avg;
|
||||
clusterAvg[clusterID] = avgNormalized;
|
||||
}
|
||||
if (updatesForClusterSummary.isNotEmpty) {
|
||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||
|
||||
Reference in New Issue
Block a user