diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index f504eafa4b..27df434bce 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -662,14 +662,16 @@ class FaceClusteringService { newEmbeddings.add(oldEmbeddings); final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); + final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); newClusterSummaries[clusterId] = ( - EVector(values: newMeanVector.toList()).writeToBuffer(), + EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), oldCount + newCount ); } else { - final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount; + final newMeanVector = newEmbeddings.reduce((a, b) => a + b); + final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); newClusterSummaries[clusterId] = - (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount); + (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount); } } log( diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 06167d20d4..d8faa8c0ee 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -868,7 +868,8 @@ class ClusterFeedbackService { ), ); final avg = vectors.reduce((a, b) => a + b) / vectors.length; - final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer(); + final avgNormalized = avg / avg.norm(); + final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer(); updatesForClusterSummary[clusterID] = (avgEmbeddingBuffer, embeddings.length); // store the intermediate updates @@ -882,7 +883,7 @@ class ClusterFeedbackService { ); } } - clusterAvg[clusterID] = avg; + clusterAvg[clusterID] = avgNormalized; } if (updatesForClusterSummary.isNotEmpty) { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);