From f1fd74b11966b0a30d756a5e5e5bdeafd52ad7f1 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 4 Apr 2024 18:47:09 +0530 Subject: [PATCH] [mob] Option to add/remove face to cluster from file info --- mobile/lib/face/db.dart | 12 ++ .../face_ml/feedback/cluster_feedback.dart | 6 + .../ui/viewer/file_details/face_widget.dart | 179 ++++++++++++------ .../file_details/faces_item_widget.dart | 26 ++- 4 files changed, 161 insertions(+), 62 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 0dca25e939..024ce41c42 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -763,4 +763,16 @@ class FaceMLDataDB { } await forceUpdateClusterIds(faceIDToClusterID); } + + Future addFacesToCluster( + List faceIDs, + int clusterID, + ) async { + final faceIDToClusterID = {}; + for (final faceID in faceIDs) { + faceIDToClusterID[faceID] = clusterID; + } + + await forceUpdateClusterIds(faceIDToClusterID); + } } diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 4437192895..cef19c7a28 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -340,6 +340,12 @@ class ClusterFeedbackService { return; } + Future addFilesToCluster(List faceIDs, int clusterID) async { + await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID); + Bus.instance.fire(PeopleChangedEvent()); + return; + } + Future checkAndDoAutomaticMerges(Person p) async { final faceMlDb = FaceMLDataDB.instance; final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index aec0080016..c11695b588 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -2,12 +2,14 @@ import "dart:developer" show log; import "dart:io" show Platform; import "dart:typed_data"; +import "package:flutter/cupertino.dart"; import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/material.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/face.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/search_service.dart"; import "package:photos/theme/ente_theme.dart"; import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; @@ -16,13 +18,15 @@ import "package:photos/ui/viewer/people/cropped_face_image_view.dart"; import "package:photos/ui/viewer/people/people_page.dart"; import "package:photos/utils/face/face_box_crop.dart"; import "package:photos/utils/thumbnail_util.dart"; +// import "package:photos/utils/toast_util.dart"; -class FaceWidget extends StatelessWidget { +class FaceWidget extends StatefulWidget { final EnteFile file; final Face face; final Person? person; final int? clusterID; final bool highlight; + final bool editMode; const FaceWidget( this.file, @@ -30,9 +34,17 @@ class FaceWidget extends StatelessWidget { this.person, this.clusterID, this.highlight = false, + this.editMode = false, Key? key, }) : super(key: key); + @override + State createState() => _FaceWidgetState(); +} + +class _FaceWidgetState extends State { + bool isJustRemoved = false; + @override Widget build(BuildContext context) { if (Platform.isIOS || Platform.isAndroid) { @@ -43,22 +55,27 @@ class FaceWidget extends StatelessWidget { final ImageProvider imageProvider = MemoryImage(snapshot.data!); return GestureDetector( onTap: () async { - log( - "FaceWidget is tapped, with person $person and clusterID $clusterID", - name: "FaceWidget", - ); - if (person == null && clusterID == null) { + if (widget.editMode) { + _cornerIconPressed(); return; } - if (person != null) { + + log( + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", + name: "FaceWidget", + ); + if (widget.person == null && widget.clusterID == null) { + return; + } + if (widget.person != null) { await Navigator.of(context).push( MaterialPageRoute( builder: (context) => PeoplePage( - person: person!, + person: widget.person!, ), ), ); - } else if (clusterID != null) { + } else if (widget.clusterID != null) { final fileIdsToClusterIds = await FaceMLDataDB.instance.getFileIdToClusterIds(); final files = await SearchService.instance.getAllFiles(); @@ -66,7 +83,7 @@ class FaceWidget extends StatelessWidget { .where( (file) => fileIdsToClusterIds[file.uploadedFileID] - ?.contains(clusterID) ?? + ?.contains(widget.clusterID) ?? false, ) .toList(); @@ -74,7 +91,7 @@ class FaceWidget extends StatelessWidget { MaterialPageRoute( builder: (context) => ClusterPage( clusterFiles, - clusterID: clusterID!, + clusterID: widget.clusterID!, ), ), ); @@ -82,46 +99,69 @@ class FaceWidget extends StatelessWidget { }, child: Column( children: [ - // TODO: the edges of the green line are still not properly rounded around ClipRRect - Container( - height: 60, - width: 60, - decoration: ShapeDecoration( - shape: RoundedRectangleBorder( - borderRadius: - const BorderRadius.all(Radius.elliptical(16, 12)), - side: highlight - ? BorderSide( - color: getEnteColorScheme(context).primary700, - width: 1.0, - ) - : BorderSide.none, - ), - ), - child: ClipRRect( - borderRadius: - const BorderRadius.all(Radius.elliptical(16, 12)), - child: SizedBox( - width: 60, + Stack( + children: [ + Container( height: 60, - child: Image( - image: imageProvider, - fit: BoxFit.cover, + width: 60, + decoration: ShapeDecoration( + shape: RoundedRectangleBorder( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), + side: widget.highlight + ? BorderSide( + color: + getEnteColorScheme(context).primary700, + width: 1.0, + ) + : BorderSide.none, + ), + ), + child: ClipRRect( + borderRadius: + const BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: Image( + image: imageProvider, + fit: BoxFit.cover, + ), + ), ), ), - ), + // TODO: the edges of the green line are still not properly rounded around ClipRRect + if (widget.editMode) + Positioned( + right: 0, + top: 0, + child: GestureDetector( + onTap: _cornerIconPressed, + child: isJustRemoved + ? const Icon( + CupertinoIcons.add_circled_solid, + color: Colors.green, + ) + : const Icon( + Icons.cancel, + color: Colors.red, + ), + ), + ), + ], ), const SizedBox(height: 8), - if (person != null) + if (widget.person != null) Text( - person!.attr.name.trim(), + widget.person!.attr.name.trim(), style: Theme.of(context).textTheme.bodySmall, overflow: TextOverflow.ellipsis, maxLines: 1, ), if (kDebugMode) Text( - 'S: ${face.score.toStringAsFixed(3)}', + 'S: ${widget.face.score.toStringAsFixed(3)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), @@ -168,21 +208,21 @@ class FaceWidget extends StatelessWidget { return GestureDetector( onTap: () async { log( - "FaceWidget is tapped, with person $person and clusterID $clusterID", + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", name: "FaceWidget", ); - if (person == null && clusterID == null) { + if (widget.person == null && widget.clusterID == null) { return; } - if (person != null) { + if (widget.person != null) { await Navigator.of(context).push( MaterialPageRoute( builder: (context) => PeoplePage( - person: person!, + person: widget.person!, ), ), ); - } else if (clusterID != null) { + } else if (widget.clusterID != null) { final fileIdsToClusterIds = await FaceMLDataDB.instance.getFileIdToClusterIds(); final files = await SearchService.instance.getAllFiles(); @@ -190,7 +230,7 @@ class FaceWidget extends StatelessWidget { .where( (file) => fileIdsToClusterIds[file.uploadedFileID] - ?.contains(clusterID) ?? + ?.contains(widget.clusterID) ?? false, ) .toList(); @@ -198,7 +238,7 @@ class FaceWidget extends StatelessWidget { MaterialPageRoute( builder: (context) => ClusterPage( clusterFiles, - clusterID: clusterID!, + clusterID: widget.clusterID!, ), ), ); @@ -213,7 +253,7 @@ class FaceWidget extends StatelessWidget { shape: RoundedRectangleBorder( borderRadius: const BorderRadius.all(Radius.elliptical(16, 12)), - side: highlight + side: widget.highlight ? BorderSide( color: getEnteColorScheme(context).primary700, width: 2.0, @@ -228,23 +268,23 @@ class FaceWidget extends StatelessWidget { width: 60, height: 60, child: CroppedFaceImageView( - enteFile: file, - face: face, + enteFile: widget.file, + face: widget.face, ), ), ), ), const SizedBox(height: 8), - if (person != null) + if (widget.person != null) Text( - person!.attr.name.trim(), + widget.person!.attr.name.trim(), style: Theme.of(context).textTheme.bodySmall, overflow: TextOverflow.ellipsis, maxLines: 1, ), if (kDebugMode) Text( - 'S: ${face.score.toStringAsFixed(3)}', + 'S: ${widget.face.score.toStringAsFixed(3)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), @@ -256,36 +296,55 @@ class FaceWidget extends StatelessWidget { } } + void _cornerIconPressed() async { + log('face widget (file info) corner icon is pressed'); + try { + if (isJustRemoved) { + await ClusterFeedbackService.instance + .addFilesToCluster([widget.face.faceID], widget.clusterID!); + } else { + await ClusterFeedbackService.instance + .removeFilesFromCluster([widget.file], widget.clusterID!); + } + + setState(() { + isJustRemoved = !isJustRemoved; + }); + } catch (e, s) { + log("removing face/file from cluster from file info widget failed: $e, \n $s"); + } + } + Future getFaceCrop() async { try { - final Uint8List? cachedFace = faceCropCache.get(face.faceID); + final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID); if (cachedFace != null) { return cachedFace; } - final faceCropCacheFile = cachedFaceCropPath(face.faceID); + final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID); if ((await faceCropCacheFile.exists())) { final data = await faceCropCacheFile.readAsBytes(); - faceCropCache.put(face.faceID, data); + faceCropCache.put(widget.face.faceID, data); return data; } final result = await pool.withResource( () async => await getFaceCrops( - file, + widget.file, { - face.faceID: face.detection.box, + widget.face.faceID: widget.face.detection.box, }, ), ); - final Uint8List? computedCrop = result?[face.faceID]; + final Uint8List? computedCrop = result?[widget.face.faceID]; if (computedCrop != null) { - faceCropCache.put(face.faceID, computedCrop); + faceCropCache.put(widget.face.faceID, computedCrop); faceCropCacheFile.writeAsBytes(computedCrop).ignore(); } return computedCrop; } catch (e, s) { log( - "Error getting face for faceID: ${face.faceID}", + "Error getting face for faceID: ${widget.face.faceID}", error: e, stackTrace: s, ); diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart index 3a541a4779..505dbf43bd 100644 --- a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart +++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart @@ -9,23 +9,44 @@ import "package:photos/ui/components/buttons/chip_button_widget.dart"; import "package:photos/ui/components/info_item_widget.dart"; import "package:photos/ui/viewer/file_details/face_widget.dart"; -class FacesItemWidget extends StatelessWidget { +class FacesItemWidget extends StatefulWidget { final EnteFile file; const FacesItemWidget(this.file, {super.key}); + @override + State createState() => _FacesItemWidgetState(); +} + +class _FacesItemWidgetState extends State { + bool editMode = false; + + @override + void initState() { + super.initState(); + setState(() {}); + } + @override Widget build(BuildContext context) { return InfoItemWidget( key: const ValueKey("Faces"), leadingIcon: Icons.face_retouching_natural_outlined, - subtitleSection: _faceWidgets(context, file), + subtitleSection: _faceWidgets(context, widget.file, editMode), hasChipButtons: true, + editOnTap: _toggleEditMode, ); } + void _toggleEditMode() { + setState(() { + editMode = !editMode; + }); + } + Future> _faceWidgets( BuildContext context, EnteFile file, + bool editMode, ) async { try { if (file.uploadedFileID == null) { @@ -84,6 +105,7 @@ class FacesItemWidget extends StatelessWidget { clusterID: clusterID, person: person, highlight: highlight, + editMode: highlight ? editMode : false, ), ); }