diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 5945175470..905c2a3c49 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -846,6 +846,21 @@ class FaceMLDataDB { return result; } + Future> getClusterToClusterSummary(Iterable clusterIDs) async { + final db = await instance.sqliteAsyncDB; + final Map result = {}; + final rows = await db.getAll( + 'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})', + ); + for (final r in rows) { + final id = r[clusterIDColumn] as int; + final avg = r[avgColumn] as Uint8List; + final count = r[countColumn] as int; + result[id] = (avg, count); + } + return result; + } + Future> getClusterIDToPersonID() async { final db = await instance.database; final List> maps = await db.rawQuery( diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 10f7b90d66..c6c185741c 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -877,21 +877,21 @@ class ClusterFeedbackService { final faceMlDb = FaceMLDataDB.instance; // Get the cluster averages for the person's clusters and the suggestions' clusters - final Map clusterToSummary = - await faceMlDb.getAllClusterSummary(); + final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); + final Map personClusterToSummary = + await faceMlDb.getClusterToClusterSummary(personClusters); final clusterSummaryCallTime = DateTime.now(); // Calculate the avg embedding of the person - final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); final personEmbeddingsCount = personClusters - .map((e) => clusterToSummary[e]!.$2) + .map((e) => personClusterToSummary[e]!.$2) .reduce((a, b) => a + b); final List personAvg = List.filled(192, 0); for (final personClusterID in personClusters) { - final personClusterBlob = clusterToSummary[personClusterID]!.$1; + final personClusterBlob = personClusterToSummary[personClusterID]!.$1; final personClusterAvg = EVector.fromBuffer(personClusterBlob).values; final clusterWeight = - clusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; + personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; for (int i = 0; i < personClusterAvg.length; i++) { personAvg[i] += personClusterAvg[i] * clusterWeight; // Weighted sum of the cluster averages