diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index 6ed5b85e3c..1196bf5689 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -477,6 +477,25 @@ class MLDataDB { return result; } + Future>> getClusterIdToFaceIdsForPerson( + String personID, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceClustersTable.$clusterIDColumn, $faceIDColumn FROM $clusterPersonTable ' + 'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn ' + 'WHERE $personIdColumn = ?', + [personID], + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + result.putIfAbsent(clusterID, () => {}).add(faceID); + } + return result; + } + Future> getFaceIDsForPerson(String personID) async { final db = await instance.asyncDB; final faceIdsResult = await db.getAll( diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index f540897db5..cf8513a613 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -8,6 +8,7 @@ import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; import "package:photos/models/api/entity/type.dart"; +import "package:photos/models/base/id.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/models/ml/face/face.dart'; import "package:photos/models/ml/face/person.dart"; @@ -280,6 +281,8 @@ class PersonService { for (var e in entities) { final personData = PersonData.fromJson(json.decode(e.data)); int faceCount = 0; + + // Locally store the assignment of faces to clusters and people for (var cluster in personData.assigned!) { faceCount += cluster.faces.length; for (var faceId in cluster.faces) { @@ -305,6 +308,42 @@ class PersonService { "Person ${e.id} ${personData.name} has ${personData.assigned!.length} clusters with $faceCount faces", ); } + + // Locally store the rejection of faces to a person + if (personData.rejectedFaceIDs != null) { + final personFaceIDs = await faceMLDataDB.getFaceIDsForPerson(e.id); + final rejectedFaceIDsSet = personData.rejectedFaceIDs!.toSet(); + final remotelyRejectedFaceIDs = + rejectedFaceIDsSet.intersection(personFaceIDs); + if (remotelyRejectedFaceIDs.isNotEmpty) { + logger.info( + "Person ${e.id} ${personData.name} has ${remotelyRejectedFaceIDs.length} rejected faces", + ); + // Assign rejected faces to new clusters + for (final faceId in remotelyRejectedFaceIDs) { + faceIdToClusterID[faceId] = newClusterID(); + } + // Check that we don't have any empty clusters now + final dbPersonClusterInfo = + await faceMLDataDB.getClusterIdToFaceIdsForPerson(e.id); + for (final clusterIdToFaceIDs in dbPersonClusterInfo.entries) { + final clusterID = clusterIdToFaceIDs.key; + final faceIDs = clusterIdToFaceIDs.value; + faceIDs.removeWhere( + (faceID) => remotelyRejectedFaceIDs.contains(faceID), + ); + if (faceIDs.isEmpty) { + logger.info( + "Cluster $clusterID for person ${e.id} ${personData.name} is empty due to rejected faces from remote, removing the cluster from person", + ); + await faceMLDataDB.removeClusterToPerson( + personID: e.id, + clusterID: clusterID, + ); + } + } + } + } } logger.info("Storing feedback for ${faceIdToClusterID.length} faces"); await faceMLDataDB.updateFaceIdToClusterId(faceIdToClusterID);