From 38381f6bb0eb4143fefb26c8425067660ec8ad10 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 13 Apr 2024 16:02:49 +0530 Subject: [PATCH 01/39] [mob] Tiny face widget debug change --- mobile/lib/ui/viewer/file_details/face_widget.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 4bd4a7bb59..0fe8958e34 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -164,7 +164,7 @@ class _FaceWidgetState extends State { ), if (kDebugMode) Text( - 'B: ${widget.face.blur.toStringAsFixed(3)}', + 'B: ${widget.face.blur.toStringAsFixed(0)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), From ef4135f378da4dcba56216a824886dc73d5640d5 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 13 Apr 2024 16:41:51 +0530 Subject: [PATCH 02/39] [mob] tiny change --- .../machine_learning/face_ml/feedback/cluster_feedback.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d89b007cce..d1cdb2b7b1 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -194,7 +194,7 @@ class ClusterFeedbackService { // TODO: iterate over this method to find sweet spot Future>> breakUpCluster( int clusterID, { - useDbscan = false, + bool useDbscan = false, }) async { _logger.info( 'breakUpCluster called for cluster $clusterID with dbscan $useDbscan', From e3b1cb8014b48eb91ec2cfa1f78a73661bf8f17d Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 14:33:13 +0530 Subject: [PATCH 03/39] [mob] Debug log blur values of cluster when opening cluster page --- mobile/lib/face/db.dart | 20 ++++++++++++++++ .../face_ml/feedback/cluster_feedback.dart | 24 +++++++++++++++++++ mobile/lib/ui/viewer/people/cluster_page.dart | 7 ++++++ 3 files changed, 51 insertions(+) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index f192607ec9..e12242987d 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -344,6 +344,26 @@ class FaceMLDataDB { return maps.map((e) => e[fcFaceId] as String).toSet(); } + Future> getBlurValuesForCluster(int clusterID) async { + final db = await instance.sqliteAsyncDB; + const String query = ''' + SELECT $facesTable.$faceBlur + FROM $facesTable + JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$fcFaceId + WHERE $faceClustersTable.$fcClusterID = ? + '''; + // const String query2 = ''' + // SELECT $faceBlur + // FROM $facesTable + // WHERE $faceIDColumn IN (SELECT $fcFaceId FROM $faceClustersTable WHERE $fcClusterID = ?) + // '''; + final List> maps = await db.getAll( + query, + [clusterID], + ); + return maps.map((e) => e[faceBlur] as double).toSet(); + } + Future> getFaceIdsToClusterIds( Iterable faceIds, ) async { diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d1cdb2b7b1..6ce0aeb1cd 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -295,6 +295,30 @@ class ClusterFeedbackService { return clusterIdToFaceIds; } + Future debugLogClusterBlurValues( + int clusterID, { + int? clusterSize, + }) async { + final List blurValues = await FaceMLDataDB.instance + .getBlurValuesForCluster(clusterID) + .then((value) => value.toList()); + + // Round the blur values to integers + final blurValuesIntegers = + blurValues.map((value) => value.round()).toList(); + + // Sort the blur values in ascending order + blurValuesIntegers.sort(); + + // Log the sorted blur values + + _logger.info( + "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers", + ); + + return; + } + /// Returns a map of person's clusterID to map of closest clusterID to with disstance Future>> getSuggestionsUsingMean( PersonEntity p, { diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index e2cdaca525..7c3a7c3370 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -1,5 +1,6 @@ import "dart:async"; +import "package:flutter/foundation.dart"; import 'package:flutter/material.dart'; import "package:flutter_animate/flutter_animate.dart"; import 'package:photos/core/event_bus.dart'; @@ -69,6 +70,12 @@ class _ClusterPageState extends State { setState(() {}); } }); + kDebugMode + ? ClusterFeedbackService.instance.debugLogClusterBlurValues( + widget.clusterID, + clusterSize: files.length, + ) + : null; } @override From 624a06c3f88585d78456a87443bf266d354c7228 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 14:37:56 +0530 Subject: [PATCH 04/39] [mob] Deprecate methods --- .../machine_learning/face_ml/face_detection/detection.dart | 1 + 1 file changed, 1 insertion(+) diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index 7aa0881417..c95a5e0d53 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -16,6 +16,7 @@ abstract class Detection { String toString(); } +@Deprecated('Old method only used in other deprecated methods') extension BBoxExtension on List { void roundBoxToDouble() { final widthRounded = (this[2] - this[0]).roundToDouble(); From e3b8d8975f7c09a27e7606707e2b62aff79e57a6 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 14:51:23 +0530 Subject: [PATCH 05/39] [mob] Better blur detection handling background noise --- .../face_ml/face_detection/detection.dart | 50 +++++++++++++- .../blur_detection_service.dart | 68 +++++++++++++++---- mobile/lib/utils/image_ml_util.dart | 11 ++- 3 files changed, 107 insertions(+), 22 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index c95a5e0d53..c63ef51c20 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -1,7 +1,24 @@ -import 'dart:math' show sqrt, pow; +import 'dart:math' show max, min, pow, sqrt; import "package:photos/face/model/dimension.dart"; +enum FaceDirection { left, right, straight } + +extension FaceDirectionExtension on FaceDirection { + String toDirectionString() { + switch (this) { + case FaceDirection.left: + return 'L'; + case FaceDirection.right: + return 'R'; + case FaceDirection.straight: + return 'S'; + default: + throw Exception('Unknown FaceDirection'); + } + } +} + abstract class Detection { final double score; @@ -426,6 +443,37 @@ class FaceDetectionAbsolute extends Detection { /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight]. double get height => yMaxBox - yMinBox; + + FaceDirection getFaceDirection() { + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0])) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + final bool noseCloseToLeftEye = + (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX; + final bool noseCloseToRightEye = + (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX; + + // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) { + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return FaceDirection.left; + // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) { + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return FaceDirection.right; + } + + return FaceDirection.straight; + } } List relativeToAbsoluteDetections({ diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart index 43f6b252d2..a2a28c3626 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart @@ -1,4 +1,5 @@ import 'package:logging/logging.dart'; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; class BlurDetectionService { @@ -12,8 +13,10 @@ class BlurDetectionService { Future<(bool, double)> predictIsBlurGrayLaplacian( List> grayImage, { int threshold = kLaplacianThreshold, + FaceDirection faceDirection = FaceDirection.straight, }) async { - final List> laplacian = _applyLaplacian(grayImage); + final List> laplacian = + _applyLaplacian(grayImage, faceDirection: faceDirection); final double variance = _calculateVariance(laplacian); _logger.info('Variance: $variance'); return (variance < threshold, variance); @@ -46,43 +49,80 @@ class BlurDetectionService { return variance; } - List> _padImage(List> image) { + List> _padImage( + List> image, { + int removeSideColumns = 56, + FaceDirection faceDirection = FaceDirection.straight, + }) { + // Exception is removeSideColumns is not even + if (removeSideColumns % 2 != 0) { + throw Exception('removeSideColumns must be even'); + } + final int numRows = image.length; final int numCols = image[0].length; + final int paddedNumCols = numCols + 2 - removeSideColumns; + final int paddedNumRows = numRows + 2; // Create a new matrix with extra padding final List> paddedImage = List.generate( - numRows + 2, - (i) => List.generate(numCols + 2, (j) => 0, growable: false), + paddedNumRows, + (i) => List.generate( + paddedNumCols, + (j) => 0, + growable: false, + ), growable: false, ); - // Copy original image into the center of the padded image - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - paddedImage[i + 1][j + 1] = image[i][j]; + // Copy original image into the center of the padded image, taking into account the face direction + if (faceDirection == FaceDirection.straight) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = + image[i][j + (removeSideColumns / 2).round()]; + } + } + // If the face is facing left, we only take the right side of the face image + } else if (faceDirection == FaceDirection.left) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns]; + } + } + // If the face is facing right, we only take the left side of the face image + } else if (faceDirection == FaceDirection.right) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = image[i][j]; + } } } // Reflect padding // Top and bottom rows - for (int j = 1; j <= numCols; j++) { + for (int j = 1; j <= (paddedNumCols - 2); j++) { paddedImage[0][j] = paddedImage[2][j]; // Top row paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row } // Left and right columns for (int i = 0; i < numRows + 2; i++) { paddedImage[i][0] = paddedImage[i][2]; // Left column - paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column + paddedImage[i][paddedNumCols - 1] = + paddedImage[i][paddedNumCols - 3]; // Right column } return paddedImage; } - List> _applyLaplacian(List> image) { - final List> paddedImage = _padImage(image); - final int numRows = image.length; - final int numCols = image[0].length; + List> _applyLaplacian( + List> image, { + FaceDirection faceDirection = FaceDirection.straight, + }) { + final List> paddedImage = + _padImage(image, faceDirection: faceDirection); + final int numRows = paddedImage.length - 2; + final int numCols = paddedImage[0].length - 2; final List> outputImage = List.generate( numRows, (i) => List.generate(numCols, (j) => 0, growable: false), diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart index ab29eb9192..7ce10e306f 100644 --- a/mobile/lib/utils/image_ml_util.dart +++ b/mobile/lib/utils/image_ml_util.dart @@ -1099,19 +1099,16 @@ Future<(Float32List, List, List, List, Size)> imageHeight: image.height, ); - final List>> faceLandmarks = - absoluteFaces.map((face) => face.allKeypoints).toList(); - final alignedImagesFloat32List = - Float32List(3 * width * height * faceLandmarks.length); + Float32List(3 * width * height * absoluteFaces.length); final alignmentResults = []; final isBlurs = []; final blurValues = []; int alignedImageIndex = 0; - for (final faceLandmark in faceLandmarks) { + for (final face in absoluteFaces) { final (alignmentResult, correctlyEstimated) = - SimilarityTransform.instance.estimate(faceLandmark); + SimilarityTransform.instance.estimate(face.allKeypoints); if (!correctlyEstimated) { alignedImageIndex += 3 * width * height; alignmentResults.add(AlignmentResult.empty()); @@ -1137,7 +1134,7 @@ Future<(Float32List, List, List, List, Size)> final grayscalems = blurDetectionStopwatch.elapsedMilliseconds; log('creating grayscale matrix took $grayscalems ms'); final (isBlur, blurValue) = await BlurDetectionService.instance - .predictIsBlurGrayLaplacian(faceGrayMatrix); + .predictIsBlurGrayLaplacian(faceGrayMatrix, faceDirection: face.getFaceDirection()); final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems; log('blur detection took $blurms ms'); log( From 618b152f177c5b20824ff1c97854d45b1b1aca4d Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 15:04:15 +0530 Subject: [PATCH 06/39] [mob] FaceMlService static method for full pipeline --- .../face_ml/face_ml_service.dart | 162 +++++++++--------- 1 file changed, 85 insertions(+), 77 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 4efcc444d4..0eeda2758b 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -204,83 +204,9 @@ class FaceMlService { try { switch (function) { case FaceMlOperation.analyzeImage: - final int enteFileID = args["enteFileID"] as int; - final String imagePath = args["filePath"] as String; - final int faceDetectionAddress = - args["faceDetectionAddress"] as int; - final int faceEmbeddingAddress = - args["faceEmbeddingAddress"] as int; - - final resultBuilder = - FaceMlResultBuilder.fromEnteFileID(enteFileID); - - dev.log( - "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", - ); - final stopwatchTotal = Stopwatch()..start(); - final stopwatch = Stopwatch()..start(); - - // Decode the image once to use for both face detection and alignment - final imageData = await File(imagePath).readAsBytes(); - final image = await decodeImageFromData(imageData); - final ByteData imgByteData = await getByteDataFromImage(image); - dev.log('Reading and decoding image took ' - '${stopwatch.elapsedMilliseconds} ms'); - stopwatch.reset(); - - // Get the faces - final List faceDetectionResult = - await FaceMlService.detectFacesSync( - image, - imgByteData, - faceDetectionAddress, - resultBuilder: resultBuilder, - ); - - dev.log( - "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - // If no faces were detected, return a result with no faces. Otherwise, continue. - if (faceDetectionResult.isEmpty) { - dev.log( - "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " - "${stopwatch.elapsedMilliseconds} ms"); - sendPort.send(resultBuilder.buildNoFaceDetected().toJsonString()); - break; - } - - stopwatch.reset(); - // Align the faces - final Float32List faceAlignmentResult = - await FaceMlService.alignFacesSync( - image, - imgByteData, - faceDetectionResult, - resultBuilder: resultBuilder, - ); - - dev.log("Completed `alignFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - stopwatch.reset(); - // Get the embeddings of the faces - final embeddings = await FaceMlService.embedFacesSync( - faceAlignmentResult, - faceEmbeddingAddress, - resultBuilder: resultBuilder, - ); - - dev.log("Completed `embedFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - stopwatch.stop(); - stopwatchTotal.stop(); - dev.log("Finished Analyze image (${embeddings.length} faces) with " - "uploadedFileID $enteFileID, in " - "${stopwatchTotal.elapsedMilliseconds} ms"); - - sendPort.send(resultBuilder.build().toJsonString()); + final FaceMlResult result = + await FaceMlService.analyzeImageSync(args); + sendPort.send(result.toJsonString()); break; } } catch (e, stackTrace) { @@ -875,6 +801,7 @@ class FaceMlService { } } + /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate. Future analyzeImageInSingleIsolate(EnteFile enteFile) async { _checkEnteFileForID(enteFile); await ensureInitialized(); @@ -931,6 +858,87 @@ class FaceMlService { return result; } + static Future analyzeImageSync(Map args) async { + try { + final int enteFileID = args["enteFileID"] as int; + final String imagePath = args["filePath"] as String; + final int faceDetectionAddress = args["faceDetectionAddress"] as int; + final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int; + + final resultBuilder = FaceMlResultBuilder.fromEnteFileID(enteFileID); + + dev.log( + "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", + ); + final stopwatchTotal = Stopwatch()..start(); + final stopwatch = Stopwatch()..start(); + + // Decode the image once to use for both face detection and alignment + final imageData = await File(imagePath).readAsBytes(); + final image = await decodeImageFromData(imageData); + final ByteData imgByteData = await getByteDataFromImage(image); + dev.log('Reading and decoding image took ' + '${stopwatch.elapsedMilliseconds} ms'); + stopwatch.reset(); + + // Get the faces + final List faceDetectionResult = + await FaceMlService.detectFacesSync( + image, + imgByteData, + faceDetectionAddress, + resultBuilder: resultBuilder, + ); + + dev.log( + "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + // If no faces were detected, return a result with no faces. Otherwise, continue. + if (faceDetectionResult.isEmpty) { + dev.log( + "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " + "${stopwatch.elapsedMilliseconds} ms"); + return resultBuilder.buildNoFaceDetected(); + } + + stopwatch.reset(); + // Align the faces + final Float32List faceAlignmentResult = + await FaceMlService.alignFacesSync( + image, + imgByteData, + faceDetectionResult, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `alignFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.reset(); + // Get the embeddings of the faces + final embeddings = await FaceMlService.embedFacesSync( + faceAlignmentResult, + faceEmbeddingAddress, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `embedFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.stop(); + stopwatchTotal.stop(); + dev.log("Finished Analyze image (${embeddings.length} faces) with " + "uploadedFileID $enteFileID, in " + "${stopwatchTotal.elapsedMilliseconds} ms"); + + return resultBuilder.build(); + } catch (e, s) { + dev.log("Could not analyze image: \n e: $e \n s: $s"); + rethrow; + } + } + Future _getImagePathForML( EnteFile enteFile, { FileDataForML typeOfData = FileDataForML.fileData, From b1eb6c11b2e8de6149e334bb51f03b1a1df04fb5 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 15:11:51 +0530 Subject: [PATCH 07/39] [mob] Show face direction in face widget in debug mode --- mobile/lib/face/model/detection.dart | 40 +++++++++++++++++++ .../face_ml/face_detection/detection.dart | 6 +-- .../ui/viewer/file_details/face_widget.dart | 7 ++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart index cd7ff6c64f..0718340b27 100644 --- a/mobile/lib/face/model/detection.dart +++ b/mobile/lib/face/model/detection.dart @@ -1,6 +1,9 @@ +import "dart:math" show min, max; + import "package:logging/logging.dart"; import "package:photos/face/model/box.dart"; import "package:photos/face/model/landmark.dart"; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; /// Stores the face detection data, notably the bounding box and landmarks. /// @@ -89,4 +92,41 @@ class Detection { return -1; } } + + FaceDirection getFaceDirection() { + final leftEye = [landmarks[0].x, landmarks[0].y]; + final rightEye = [landmarks[1].x, landmarks[1].y]; + final nose = [landmarks[2].x, landmarks[2].y]; + final leftMouth = [landmarks[3].x, landmarks[3].y]; + final rightMouth = [landmarks[4].x, landmarks[4].y]; + + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0])) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + final bool noseCloseToLeftEye = + (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX; + final bool noseCloseToRightEye = + (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX; + + // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) { + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return FaceDirection.left; + // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) { + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return FaceDirection.right; + } + + return FaceDirection.straight; + } } diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index c63ef51c20..7e74f872ef 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -8,11 +8,11 @@ extension FaceDirectionExtension on FaceDirection { String toDirectionString() { switch (this) { case FaceDirection.left: - return 'L'; + return 'Left'; case FaceDirection.right: - return 'R'; + return 'Right'; case FaceDirection.straight: - return 'S'; + return 'Straigt'; default: throw Exception('Unknown FaceDirection'); } diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 0fe8958e34..c33004eff4 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -9,6 +9,7 @@ import "package:photos/face/db.dart"; import "package:photos/face/model/face.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/search_service.dart"; import "package:photos/theme/ente_theme.dart"; @@ -168,6 +169,12 @@ class _FaceWidgetState extends State { style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), + if (kDebugMode) + Text( + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), if (kDebugMode) Text( 'V: ${widget.face.visibility}', From d4d9253f1ac5854403be3252bd31bc7822ac1afe Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 15:22:59 +0530 Subject: [PATCH 08/39] [mob] Minor changes --- .../machine_learning/face_ml/face_detection/detection.dart | 2 +- .../lib/services/machine_learning/face_ml/face_ml_service.dart | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index 7e74f872ef..de8535c871 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -12,7 +12,7 @@ extension FaceDirectionExtension on FaceDirection { case FaceDirection.right: return 'Right'; case FaceDirection.straight: - return 'Straigt'; + return 'Straight'; default: throw Exception('Unknown FaceDirection'); } diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 0eeda2758b..d9d964ec53 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -204,8 +204,10 @@ class FaceMlService { try { switch (function) { case FaceMlOperation.analyzeImage: + final time = DateTime.now(); final FaceMlResult result = await FaceMlService.analyzeImageSync(args); + dev.log("`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms"); sendPort.send(result.toJsonString()); break; } From f61a5f91b6de03ee3285d2558f9f8261f48ffb77 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 16 Apr 2024 19:22:27 +0530 Subject: [PATCH 09/39] [mob] Debug option to get clusters of blur values --- mobile/lib/face/db.dart | 14 ++++++++ .../face_ml/feedback/cluster_feedback.dart | 32 +++++++++++++++++++ .../debug/face_debug_section_widget.dart | 29 +++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index e12242987d..0cf952ac4e 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -364,6 +364,20 @@ class FaceMLDataDB { return maps.map((e) => e[faceBlur] as double).toSet(); } + Future> getFaceIDsToBlurValues( + int maxBlurValue, + ) async { + final db = await instance.sqliteAsyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $faceBlur FROM $facesTable WHERE $faceBlur < $maxBlurValue AND $faceBlur > 1 ORDER BY $faceBlur ASC', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[faceBlur] as double; + } + return result; + } + Future> getFaceIdsToClusterIds( Iterable faceIds, ) async { diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 6ce0aeb1cd..f82a082383 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -295,6 +295,38 @@ class ClusterFeedbackService { return clusterIdToFaceIds; } + /// WARNING: this method is purely for debugging purposes, never use in production + Future createFakeClustersByBlurValue() async { + try { + // Delete old clusters + await FaceMLDataDB.instance.resetClusterIDs(); + await FaceMLDataDB.instance.dropClustersAndPersonTable(); + final List persons = + await PersonService.instance.getPersons(); + for (final PersonEntity p in persons) { + await PersonService.instance.deletePerson(p.remoteID); + } + + // Create new fake clusters based on blur value. One for values between 0 and 10, one for 10-20, etc till 200 + final int startClusterID = DateTime.now().microsecondsSinceEpoch; + final faceIDsToBlurValues = + await FaceMLDataDB.instance.getFaceIDsToBlurValues(200); + final faceIdToCluster = {}; + for (final entry in faceIDsToBlurValues.entries) { + final faceID = entry.key; + final blurValue = entry.value; + final newClusterID = startClusterID + blurValue ~/ 10; + faceIdToCluster[faceID] = newClusterID; + } + await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + + Bus.instance.fire(PeopleChangedEvent()); + } catch (e, s) { + _logger.severe("Error in createFakeClustersByBlurValue", e, s); + rethrow; + } + } + Future debugLogClusterBlurValues( int clusterID, { int? clusterSize, diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index 85aa992a31..f1c835fa26 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -8,6 +8,7 @@ import "package:photos/events/people_changed_event.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/theme/ente_theme.dart'; import 'package:photos/ui/components/captioned_text_widget.dart'; @@ -284,6 +285,34 @@ class _FaceDebugSectionWidgetState extends State { }, ), sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Rank blurs", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "This will delete all clusters and put blurry faces in separate clusters per ten points.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await ClusterFeedbackService.instance + .createFakeClustersByBlurValue(); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('Failed to rank faces on blur values ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + }, + ), + sectionOptionSpacing, MenuItemWidget( captionedTextWidget: const CaptionedTextWidget( title: "Drop embeddings & feedback", From 6f26901073ef15380164c93c9ae3d330691b1984 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 17 Apr 2024 14:54:31 +0530 Subject: [PATCH 10/39] [mob] Minor refactor of clustering methods --- .../face_clustering_service.dart | 58 ++++++------------- 1 file changed, 19 insertions(+), 39 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index b5f17b54c7..c33b9992b6 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -100,31 +100,11 @@ class FaceClusteringService { try { switch (function) { case ClusterOperation.linearIncrementalClustering: - final input = args['input'] as Map; - final fileIDToCreationTime = - args['fileIDToCreationTime'] as Map?; - final distanceThreshold = args['distanceThreshold'] as double; - final offset = args['offset'] as int?; - final result = FaceClusteringService._runLinearClustering( - input, - fileIDToCreationTime: fileIDToCreationTime, - distanceThreshold: distanceThreshold, - offset: offset, - ); + final result = FaceClusteringService._runLinearClustering(args); sendPort.send(result); break; case ClusterOperation.dbscanClustering: - final input = args['input'] as Map; - final fileIDToCreationTime = - args['fileIDToCreationTime'] as Map?; - final eps = args['eps'] as double; - final minPts = args['minPts'] as int; - final result = FaceClusteringService._runDbscanClustering( - input, - fileIDToCreationTime: fileIDToCreationTime, - eps: eps, - minPts: minPts, - ); + final result = FaceClusteringService._runDbscanClustering(args); sendPort.send(result); break; } @@ -194,7 +174,7 @@ class FaceClusteringService { _inactivityTimer?.cancel(); } - /// Runs the clustering algorithm on the given [input], in an isolate. + /// Runs the clustering algorithm [_runLinearClustering] on the given [input], in an isolate. /// /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. /// @@ -299,19 +279,19 @@ class FaceClusteringService { return clusterFaceIDs; } - static Map _runLinearClustering( - Map x, { - Map? fileIDToCreationTime, - double distanceThreshold = kRecommendedDistanceThreshold, - int? offset, - }) { + static Map _runLinearClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + final offset = args['offset'] as int?; + log( - "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", ); // Organize everything into a list of FaceInfo objects final List faceInfos = []; - for (final entry in x.entries) { + for (final entry in input.entries) { faceInfos.add( FaceInfo( faceID: entry.key, @@ -517,14 +497,14 @@ class FaceClusteringService { ); } - static List> _runDbscanClustering( - Map x, { - Map? fileIDToCreationTime, - double eps = 0.3, - int minPts = 5, - }) { + static List> _runDbscanClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final eps = args['eps'] as double; + final minPts = args['minPts'] as int; + log( - "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", ); final DBSCAN dbscan = DBSCAN( @@ -535,7 +515,7 @@ class FaceClusteringService { // Organize everything into a list of FaceInfo objects final List faceInfos = []; - for (final entry in x.entries) { + for (final entry in input.entries) { faceInfos.add( FaceInfo( faceID: entry.key, From 72e677e9e50fddceca637bb8c8ed2c5fd0bfa5e3 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 17 Apr 2024 16:36:45 +0530 Subject: [PATCH 11/39] [mob] Minimum cluster size of 20 --- mobile/lib/services/search_service.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 7e1eb5eaaa..3f54187c16 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -824,7 +824,7 @@ class SearchService { "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", ); } - if (files.length < 3 && sortedClusterIds.length > 3) { + if (files.length < 20 && sortedClusterIds.length > 3) { continue; } facesResult.add( From 51d15cc4414eb0e1a96914383ca6631b842e68a4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 17 Apr 2024 16:38:47 +0530 Subject: [PATCH 12/39] [mob] Clustering with dynamic threshold based on face blur and score --- mobile/lib/face/db.dart | 62 +++++++++++++++++-- mobile/lib/face/model/face.dart | 2 +- .../face_clustering_service.dart | 50 ++++++++++++--- .../face_info_for_clustering.dart | 18 ++++++ .../blur_detection_service.dart | 2 +- .../face_filtering_constants.dart | 3 +- .../face_ml/face_ml_result.dart | 4 +- .../face_ml/face_ml_service.dart | 20 +++--- .../face_ml/feedback/cluster_feedback.dart | 15 ++++- 9 files changed, 146 insertions(+), 30 deletions(-) create mode 100644 mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 0cf952ac4e..6b48ef7cac 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart'; import "package:photos/face/db_model_mappers.dart"; import "package:photos/face/model/face.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:sqflite/sqflite.dart'; import 'package:sqlite_async/sqlite_async.dart' as sqlite_async; @@ -444,12 +445,63 @@ class FaceMLDataDB { ); } + Future> getFaceInfoForClustering({ + double minScore = kMinHighQualityFaceScore, + int minClarity = kLaplacianHardThreshold, + int maxFaces = 20000, + int offset = 0, + int batchSize = 10000, + }) async { + final EnteWatch w = EnteWatch("getFaceEmbeddingMap")..start(); + w.logAndReset( + 'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize', + ); + final db = await instance.sqliteAsyncDB; + + final Set result = {}; + while (true) { + // Query a batch of rows + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur FROM $facesTable' + ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity' + ' ORDER BY $faceIDColumn' + ' DESC LIMIT $batchSize OFFSET $offset', + ); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + final List faceIds = []; + for (final map in maps) { + faceIds.add(map[faceIDColumn] as String); + } + final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds); + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + final faceInfo = FaceInfoForClustering( + faceID: faceID, + clusterId: faceIdToClusterId[faceID], + embeddingBytes: map[faceEmbeddingBlob] as Uint8List, + faceScore: map[faceScore] as double, + blurValue: map[faceBlur] as double, + ); + result.add(faceInfo); + } + if (result.length >= maxFaces) { + break; + } + offset += batchSize; + } + w.stopWithLog('done reading face embeddings ${result.length}'); + return result; + } + /// Returns a map of faceID to record of clusterId and faceEmbeddingBlob /// /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] Future> getFaceEmbeddingMap({ double minScore = kMinHighQualityFaceScore, - int minClarity = kLaplacianThreshold, + int minClarity = kLaplacianHardThreshold, int maxFaces = 20000, int offset = 0, int batchSize = 10000, @@ -515,7 +567,7 @@ class FaceMLDataDB { facesTable, columns: [faceIDColumn, faceEmbeddingBlob], where: - '$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', + '$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', limit: batchSize, offset: offset, orderBy: '$faceIDColumn DESC', @@ -542,7 +594,7 @@ class FaceMLDataDB { }) async { final db = await instance.sqliteAsyncDB; final List> maps = await db.getAll( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold', ); return maps.first['count'] as int; } @@ -551,7 +603,7 @@ class FaceMLDataDB { final db = await instance.sqliteAsyncDB; final List> totalFacesMaps = await db.getAll( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', ); final int totalFaces = totalFacesMaps.first['count'] as int; @@ -564,7 +616,7 @@ class FaceMLDataDB { } Future getBlurryFaceCount([ - int blurThreshold = kLaplacianThreshold, + int blurThreshold = kLaplacianHardThreshold, ]) async { final db = await instance.database; final List> maps = await db.rawQuery( diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/face/model/face.dart index 631eeb141c..c480bde794 100644 --- a/mobile/lib/face/model/face.dart +++ b/mobile/lib/face/model/face.dart @@ -20,7 +20,7 @@ class Face { final double blur; FileInfo? fileInfo; - bool get isBlurry => blur < kLaplacianThreshold; + bool get isBlurry => blur < kLaplacianHardThreshold; bool get hasHighScore => score > kMinHighQualityFaceScore; diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index c33b9992b6..d2445c99e9 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -9,12 +9,16 @@ import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:simple_cluster/simple_cluster.dart"; import "package:synchronized/synchronized.dart"; class FaceInfo { final String faceID; + final double? faceScore; + final double? blurValue; final List? embedding; final Vector? vEmbedding; int? clusterId; @@ -23,6 +27,8 @@ class FaceInfo { int? fileCreationTime; FaceInfo({ required this.faceID, + this.faceScore, + this.blurValue, this.embedding, this.vEmbedding, this.clusterId, @@ -49,6 +55,7 @@ class FaceClusteringService { bool isRunning = false; static const kRecommendedDistanceThreshold = 0.24; + static const kConservativeDistanceThreshold = 0.06; // singleton pattern FaceClusteringService._privateConstructor(); @@ -180,9 +187,11 @@ class FaceClusteringService { /// /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic. Future?> predictLinear( - Map input, { + Set input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, + double conservativeDistanceThreshold = kConservativeDistanceThreshold, + bool useDynamicThreshold = true, int? offset, }) async { if (input.isEmpty) { @@ -212,6 +221,8 @@ class FaceClusteringService { 'input': input, 'fileIDToCreationTime': fileIDToCreationTime, 'distanceThreshold': distanceThreshold, + 'conservativeDistanceThreshold': conservativeDistanceThreshold, + 'useDynamicThreshold': useDynamicThreshold, 'offset': offset, } ), @@ -280,9 +291,13 @@ class FaceClusteringService { } static Map _runLinearClustering(Map args) { - final input = args['input'] as Map; + // final input = args['input'] as Map; + final input = args['input'] as Set; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; final distanceThreshold = args['distanceThreshold'] as double; + final conservativeDistanceThreshold = + args['conservativeDistanceThreshold'] as double; + final useDynamicThreshold = args['useDynamicThreshold'] as bool; final offset = args['offset'] as int?; log( @@ -291,17 +306,19 @@ class FaceClusteringService { // Organize everything into a list of FaceInfo objects final List faceInfos = []; - for (final entry in input.entries) { + for (final face in input) { faceInfos.add( FaceInfo( - faceID: entry.key, + faceID: face.faceID, + faceScore: face.faceScore, + blurValue: face.blurValue, vEmbedding: Vector.fromList( - EVector.fromBuffer(entry.value.$2).values, + EVector.fromBuffer(face.embeddingBytes).values, dtype: DType.float32, ), - clusterId: entry.value.$1, + clusterId: face.clusterId, fileCreationTime: - fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)], ), ); } @@ -341,6 +358,7 @@ class FaceClusteringService { // Make sure the first face has a clusterId final int totalFaces = sortedFaceInfos.length; + int dynamicThresholdCount = 0; if (sortedFaceInfos.isEmpty) { return {}; @@ -368,6 +386,17 @@ class FaceClusteringService { int closestIdx = -1; double closestDistance = double.infinity; + late double thresholdValue; + if (useDynamicThreshold) { + final bool badFace = + (sortedFaceInfos[i].faceScore! < kMinHighQualityFaceScore || + sortedFaceInfos[i].blurValue! < kLaplacianSoftThreshold); + thresholdValue = + badFace ? conservativeDistanceThreshold : distanceThreshold; + if (badFace) dynamicThresholdCount++; + } else { + thresholdValue = distanceThreshold; + } if (i % 250 == 0) { log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces"); } @@ -396,7 +425,7 @@ class FaceClusteringService { } } - if (closestDistance < distanceThreshold) { + if (closestDistance < thresholdValue) { if (sortedFaceInfos[closestIdx].clusterId == null) { // Ideally this should never happen, but just in case log it log( @@ -432,6 +461,11 @@ class FaceClusteringService { log( ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', ); + if (useDynamicThreshold) { + log( + "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or high blur value", + ); + } // analyze the results FaceClusteringService._analyzeClusterResults(sortedFaceInfos); diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart new file mode 100644 index 0000000000..8dd78ff1f0 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart @@ -0,0 +1,18 @@ + +import "dart:typed_data" show Uint8List; + +class FaceInfoForClustering { + final String faceID; + final int? clusterId; + final Uint8List embeddingBytes; + final double faceScore; + final double blurValue; + + FaceInfoForClustering({ + required this.faceID, + this.clusterId, + required this.embeddingBytes, + required this.faceScore, + required this.blurValue, + }); +} \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart index a2a28c3626..9c8d2d8c80 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart @@ -12,7 +12,7 @@ class BlurDetectionService { Future<(bool, double)> predictIsBlurGrayLaplacian( List> grayImage, { - int threshold = kLaplacianThreshold, + int threshold = kLaplacianHardThreshold, FaceDirection faceDirection = FaceDirection.straight, }) async { final List> laplacian = diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart index 8fb1838c04..7a3ef61e3e 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -1,7 +1,8 @@ import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; /// Blur detection threshold -const kLaplacianThreshold = 15; +const kLaplacianHardThreshold = 15; +const kLaplacianSoftThreshold = 100; /// Default blur value const kLapacianDefault = 10000.0; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart index 5ad0c4eee8..bb03de519a 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -504,7 +504,7 @@ class FaceResult { final int fileId; final String faceId; - bool get isBlurry => blurValue < kLaplacianThreshold; + bool get isBlurry => blurValue < kLaplacianHardThreshold; const FaceResult({ required this.detection, @@ -545,7 +545,7 @@ class FaceResultBuilder { int fileId = -1; String faceId = ''; - bool get isBlurry => blurValue < kLaplacianThreshold; + bool get isBlurry => blurValue < kLaplacianHardThreshold; FaceResultBuilder({ required this.fileId, diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index d9d964ec53..556a4135fa 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -310,14 +310,14 @@ class FaceMlService { int bucket = 1; while (true) { - final faceIdToEmbeddingBucket = - await FaceMLDataDB.instance.getFaceEmbeddingMap( + final faceInfoForClustering = + await FaceMLDataDB.instance.getFaceInfoForClustering( minScore: minFaceScore, maxFaces: bucketSize, offset: offset, batchSize: batchSize, ); - if (faceIdToEmbeddingBucket.isEmpty) { + if (faceInfoForClustering.isEmpty) { _logger.warning( 'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces', ); @@ -332,7 +332,7 @@ class FaceMlService { final faceIdToCluster = await FaceClusteringService.instance.predictLinear( - faceIdToEmbeddingBucket, + faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, offset: offset, ); @@ -343,7 +343,7 @@ class FaceMlService { await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); _logger.info( - 'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', + 'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', ); if (offset + bucketSize >= totalFaces) { _logger.info('All faces clustered'); @@ -355,14 +355,14 @@ class FaceMlService { } else { // Read all the embeddings from the database, in a map from faceID to embedding final clusterStartTime = DateTime.now(); - final faceIdToEmbedding = - await FaceMLDataDB.instance.getFaceEmbeddingMap( + final faceInfoForClustering = + await FaceMLDataDB.instance.getFaceInfoForClustering( minScore: minFaceScore, maxFaces: totalFaces, ); final gotFaceEmbeddingsTime = DateTime.now(); _logger.info( - 'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', + 'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', ); // Read the creation times from Files DB, in a map from fileID to creation time @@ -374,7 +374,7 @@ class FaceMlService { // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID final faceIdToCluster = await FaceClusteringService.instance.predictLinear( - faceIdToEmbedding, + faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, ); if (faceIdToCluster == null) { @@ -383,7 +383,7 @@ class FaceMlService { } final clusterDoneTime = DateTime.now(); _logger.info( - 'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', + 'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', ); // Store the updated clusterIDs in the database diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index f82a082383..8c38ac2764 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -233,14 +235,23 @@ class ClusterFeedbackService { } } else { final clusteringInput = embeddings.map((key, value) { - return MapEntry(key, (null, value)); - }); + return MapEntry( + key, + FaceInfoForClustering( + faceID: key, + embeddingBytes: value, + faceScore: kMinHighQualityFaceScore + 0.01, + blurValue: kLapacianDefault, + ), + ); + }).values.toSet(); final faceIdToCluster = await FaceClusteringService.instance.predictLinear( clusteringInput, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.23, + useDynamicThreshold: false, ); if (faceIdToCluster == null) { From e3fd8369013d66bdc81229a0d8357b2c7159b2b0 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 17 Apr 2024 17:46:53 +0530 Subject: [PATCH 13/39] [mob] clustering make it difficult for good faces to link to bad faces --- .../face_clustering_service.dart | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index d2445c99e9..a71da45bce 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -19,6 +19,7 @@ class FaceInfo { final String faceID; final double? faceScore; final double? blurValue; + final bool? badFace; final List? embedding; final Vector? vEmbedding; int? clusterId; @@ -29,6 +30,7 @@ class FaceInfo { required this.faceID, this.faceScore, this.blurValue, + this.badFace, this.embedding, this.vEmbedding, this.clusterId, @@ -312,6 +314,8 @@ class FaceClusteringService { faceID: face.faceID, faceScore: face.faceScore, blurValue: face.blurValue, + badFace: face.faceScore < kMinHighQualityFaceScore || + face.blurValue < kLaplacianSoftThreshold, vEmbedding: Vector.fromList( EVector.fromBuffer(face.embeddingBytes).values, dtype: DType.float32, @@ -388,12 +392,10 @@ class FaceClusteringService { double closestDistance = double.infinity; late double thresholdValue; if (useDynamicThreshold) { - final bool badFace = - (sortedFaceInfos[i].faceScore! < kMinHighQualityFaceScore || - sortedFaceInfos[i].blurValue! < kLaplacianSoftThreshold); - thresholdValue = - badFace ? conservativeDistanceThreshold : distanceThreshold; - if (badFace) dynamicThresholdCount++; + thresholdValue = sortedFaceInfos[i].badFace! + ? conservativeDistanceThreshold + : distanceThreshold; + if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++; } else { thresholdValue = distanceThreshold; } @@ -414,6 +416,10 @@ class FaceClusteringService { ); } if (distance < closestDistance) { + if (sortedFaceInfos[j].badFace! && + distance > conservativeDistanceThreshold) { + continue; + } closestDistance = distance; closestIdx = j; // if (distance < distanceThreshold) { From 45d18b187c4be7ca7be9028a6a3780ea6e71fd67 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 11:25:48 +0530 Subject: [PATCH 14/39] [mob] Add completeClustering functionality --- .../face_clustering_service.dart | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index a71da45bce..25259a0c0c 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -4,6 +4,7 @@ import "dart:isolate"; import "dart:math" show max; import "dart:typed_data"; +import "package:computer/computer.dart"; import "package:logging/logging.dart"; import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; @@ -42,6 +43,7 @@ enum ClusterOperation { linearIncrementalClustering, dbscanClustering } class FaceClusteringService { final _logger = Logger("FaceLinearClustering"); + final _computer = Computer.shared(); Timer? _inactivityTimer; final Duration _inactivityDuration = const Duration(minutes: 3); @@ -243,6 +245,45 @@ class FaceClusteringService { } } + Future> predictComplete( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + if (input.isEmpty) { + _logger.warning( + "Complete Clustering dataset of embeddings is empty, returning empty list.", + ); + return {}; + } + + // Clustering inside the isolate + _logger.info( + "Start Complete clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final startTime = DateTime.now(); + final faceIdToCluster = await _computer.compute( + runCompleteClustering, + param: { + "input": input, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + }, + taskName: "createImageEmbedding", + ) as Map; + final endTime = DateTime.now(); + _logger.info( + "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return faceIdToCluster; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + Future>> predictDbscan( Map input, { Map? fileIDToCreationTime, @@ -537,6 +578,104 @@ class FaceClusteringService { ); } + + + static Map runCompleteClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + + log( + "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", + ); + + // Organize everything into a list of FaceInfo objects + final List faceInfos = []; + for (final entry in input.entries) { + faceInfos.add( + FaceInfo( + faceID: entry.key, + vEmbedding: Vector.fromList( + EVector.fromBuffer(entry.value).values, + dtype: DType.float32, + ), + fileCreationTime: + fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + ), + ); + } + + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first + if (fileIDToCreationTime != null) { + faceInfos.sort((a, b) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); + } + + if (faceInfos.isEmpty) { + return {}; + } + final int totalFaces = faceInfos.length; + + // Start actual clustering + log( + "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering", + ); + + // set current epoch time as clusterID + int clusterID = DateTime.now().microsecondsSinceEpoch; + + // Start actual clustering + final Map newFaceIdToCluster = {}; + final stopwatchClustering = Stopwatch()..start(); + for (int i = 0; i < totalFaces; i++) { + int closestIdx = -1; + double closestDistance = double.infinity; + if (i % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed $i faces"); + } + for (int j = 0; j < totalFaces; j++) { + if (i == j) continue; + final double distance = + 1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!); + if (distance < closestDistance) { + closestDistance = distance; + closestIdx = j; + } + } + + if (closestDistance < distanceThreshold) { + if (faceInfos[closestIdx].clusterId == null) { + clusterID++; + faceInfos[closestIdx].clusterId = clusterID; + newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID; + } + faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; + newFaceIdToCluster[faceInfos[i].faceID] = + faceInfos[closestIdx].clusterId!; + } else { + clusterID++; + faceInfos[i].clusterId = clusterID; + newFaceIdToCluster[faceInfos[i].faceID] = clusterID; + } + } + + stopwatchClustering.stop(); + log( + ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', + ); + + return newFaceIdToCluster; + } + static List> _runDbscanClustering(Map args) { final input = args['input'] as Map; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; From 7a5e1263e03e0e0346e842439c0ed25431d845a6 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 11:26:24 +0530 Subject: [PATCH 15/39] [mob] Use complete clustering for breaking up clusters --- .../face_ml/feedback/cluster_feedback.dart | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 8c38ac2764..fe64218ec8 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -13,8 +13,6 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; -import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; -import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -234,27 +232,14 @@ class ClusterFeedbackService { maxClusterID++; } } else { - final clusteringInput = embeddings.map((key, value) { - return MapEntry( - key, - FaceInfoForClustering( - faceID: key, - embeddingBytes: value, - faceScore: kMinHighQualityFaceScore + 0.01, - blurValue: kLapacianDefault, - ), - ); - }).values.toSet(); - final faceIdToCluster = - await FaceClusteringService.instance.predictLinear( - clusteringInput, + await FaceClusteringService.instance.predictComplete( + embeddings, fileIDToCreationTime: fileIDToCreationTime, - distanceThreshold: 0.23, - useDynamicThreshold: false, + distanceThreshold: 0.30, ); - if (faceIdToCluster == null) { + if (faceIdToCluster.isEmpty) { _logger.info('No clusters found'); return {}; } else { From ba58ac1358983b81ef60e3d602daaa2d1594dfd0 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 14:44:12 +0530 Subject: [PATCH 16/39] [mob] Add merges to predictComplete method --- .../face_clustering_service.dart | 91 +++++++++++++++++-- .../face_ml/feedback/cluster_feedback.dart | 26 +++++- 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 25259a0c0c..2920cd760d 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -249,6 +249,7 @@ class FaceClusteringService { Map input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, + double mergeThreshold = 0.30, }) async { if (input.isEmpty) { _logger.warning( @@ -270,6 +271,7 @@ class FaceClusteringService { "input": input, "fileIDToCreationTime": fileIDToCreationTime, "distanceThreshold": distanceThreshold, + "mergeThreshold": mergeThreshold, }, taskName: "createImageEmbedding", ) as Map; @@ -578,12 +580,11 @@ class FaceClusteringService { ); } - - static Map runCompleteClustering(Map args) { final input = args['input'] as Map; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; final distanceThreshold = args['distanceThreshold'] as double; + final mergeThreshold = args['mergeThreshold'] as double; log( "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", @@ -637,10 +638,11 @@ class FaceClusteringService { final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 0; i < totalFaces; i++) { + if (faceInfos[i].clusterId != null) continue; int closestIdx = -1; double closestDistance = double.infinity; - if (i % 250 == 0) { - log("[CompleteClustering] ${DateTime.now()} Processed $i faces"); + if (i + 1 % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces"); } for (int j = 0; j < totalFaces; j++) { if (i == j) continue; @@ -656,18 +658,91 @@ class FaceClusteringService { if (faceInfos[closestIdx].clusterId == null) { clusterID++; faceInfos[closestIdx].clusterId = clusterID; - newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID; } faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; - newFaceIdToCluster[faceInfos[i].faceID] = - faceInfos[closestIdx].clusterId!; } else { clusterID++; faceInfos[i].clusterId = clusterID; - newFaceIdToCluster[faceInfos[i].faceID] = clusterID; } } + // Now calculate the mean of the embeddings for each cluster + final Map> clusterIdToFaceInfos = {}; + for (final faceInfo in faceInfos) { + if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) { + clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo); + } else { + clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + final Map clusterIdToMeanEmbeddingAndWeight = {}; + for (final clusterId in clusterIdToFaceInfos.keys) { + final List embeddings = clusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final count = clusterIdToFaceInfos[clusterId]!.length; + final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; + clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count); + } + + // Now merge the clusters that are close to each other, based on mean embedding + final List<(int, int)> mergedClustersList = []; + final List clusterIds = + clusterIdToMeanEmbeddingAndWeight.keys.toList(); + log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); + while (true) { + if (clusterIds.length < 2) break; + double distance = double.infinity; + (int, int) clusterIDsToMerge = (-1, -1); + for (int i = 0; i < clusterIds.length; i++) { + for (int j = 0; j < clusterIds.length; j++) { + if (i == j) continue; + final double newDistance = 1.0 - + clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot( + clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1, + ); + if (newDistance < distance) { + distance = newDistance; + clusterIDsToMerge = (clusterIds[i], clusterIds[j]); + } + } + } + if (distance < mergeThreshold) { + mergedClustersList.add(clusterIDsToMerge); + final clusterID1 = clusterIDsToMerge.$1; + final clusterID2 = clusterIDsToMerge.$2; + final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1; + final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1; + final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2; + final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2; + final weight1 = count1 / (count1 + count2); + final weight2 = count2 / (count1 + count2); + clusterIdToMeanEmbeddingAndWeight[clusterID1] = ( + mean1 * weight1 + mean2 * weight2, + count1 + count2, + ); + clusterIdToMeanEmbeddingAndWeight.remove(clusterID2); + clusterIds.remove(clusterID2); + } else { + break; + } + } + log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged'); + + // Now assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + for (final mergedClusters in mergedClustersList) { + if (faceInfo.clusterId == mergedClusters.$2) { + faceInfo.clusterId = mergedClusters.$1; + } + } + } + + // Finally, assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + stopwatchClustering.stop(); log( ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index fe64218ec8..e3edccc515 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; +// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -232,14 +234,36 @@ class ClusterFeedbackService { maxClusterID++; } } else { + // final clusteringInput = embeddings + // .map((key, value) { + // return MapEntry( + // key, + // FaceInfoForClustering( + // faceID: key, + // embeddingBytes: value, + // faceScore: kMinHighQualityFaceScore + 0.01, + // blurValue: kLapacianDefault, + // ), + // ); + // }) + // .values + // .toSet(); + // final faceIdToCluster = + // await FaceClusteringService.instance.predictLinear( + // clusteringInput, + // fileIDToCreationTime: fileIDToCreationTime, + // distanceThreshold: 0.23, + // useDynamicThreshold: false, + // ); final faceIdToCluster = await FaceClusteringService.instance.predictComplete( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.30, + mergeThreshold: 0.30, ); - if (faceIdToCluster.isEmpty) { + if (faceIdToCluster == null || faceIdToCluster.isEmpty) { _logger.info('No clusters found'); return {}; } else { From 417c5eab1c0ae967a96c0dec8a71d055bf637daf Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 14:48:57 +0530 Subject: [PATCH 17/39] [mob] Rename cluster method to clarify use of Computer --- .../face_ml/face_clustering/face_clustering_service.dart | 2 +- .../machine_learning/face_ml/feedback/cluster_feedback.dart | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 2920cd760d..581f447aca 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -245,7 +245,7 @@ class FaceClusteringService { } } - Future> predictComplete( + Future> predictCompleteComputer( Map input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index e3edccc515..68cae30af8 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -256,7 +256,7 @@ class ClusterFeedbackService { // useDynamicThreshold: false, // ); final faceIdToCluster = - await FaceClusteringService.instance.predictComplete( + await FaceClusteringService.instance.predictCompleteComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.30, From afbb1c69c24b4a06ca1a617f8e70cf306c4f7ad2 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 15:20:28 +0530 Subject: [PATCH 18/39] [mob] More clustering methods in Computer within cluster --- .../face_clustering_service.dart | 100 +++++++++++++++++- .../face_ml/feedback/cluster_feedback.dart | 26 +---- 2 files changed, 99 insertions(+), 27 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 581f447aca..f8c180b4f7 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -111,7 +111,7 @@ class FaceClusteringService { try { switch (function) { case ClusterOperation.linearIncrementalClustering: - final result = FaceClusteringService._runLinearClustering(args); + final result = FaceClusteringService.runLinearClustering(args); sendPort.send(result); break; case ClusterOperation.dbscanClustering: @@ -185,7 +185,7 @@ class FaceClusteringService { _inactivityTimer?.cancel(); } - /// Runs the clustering algorithm [_runLinearClustering] on the given [input], in an isolate. + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate. /// /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. /// @@ -245,6 +245,65 @@ class FaceClusteringService { } } + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding + Future?> predictLinearComputer( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + if (input.isEmpty) { + _logger.warning( + "Linear Clustering dataset of embeddings is empty, returning empty list.", + ); + return {}; + } + + // Clustering inside the isolate + _logger.info( + "Start Linear clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final clusteringInput = input + .map((key, value) { + return MapEntry( + key, + FaceInfoForClustering( + faceID: key, + embeddingBytes: value, + faceScore: kMinHighQualityFaceScore + 0.01, + blurValue: kLapacianDefault, + ), + ); + }) + .values + .toSet(); + final startTime = DateTime.now(); + final faceIdToCluster = await _computer.compute( + runLinearClustering, + param: { + "input": clusteringInput, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + "conservativeDistanceThreshold": distanceThreshold, + "useDynamicThreshold": false, + }, + taskName: "createImageEmbedding", + ) as Map; + final endTime = DateTime.now(); + _logger.info( + "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return faceIdToCluster; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + /// Runs the clustering algorithm [runCompleteClustering] on the given [input], in computer. + /// + /// WARNING: Only use on small datasets, as it is not optimized for large datasets. Future> predictCompleteComputer( Map input, { Map? fileIDToCreationTime, @@ -286,6 +345,41 @@ class FaceClusteringService { } } + Future?> predictWithinClusterComputer( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + _logger.info( + '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold', + ); + try { + if (input.length < 100) { + final mergeThreshold = distanceThreshold + 0.06; + _logger.info( + 'Running complete clustering on ${input.length} faces with distance threshold $mergeThreshold', + ); + return predictCompleteComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + mergeThreshold: mergeThreshold, + ); + } else { + _logger.info( + 'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold', + ); + return predictLinearComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: distanceThreshold, + ); + } + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + Future>> predictDbscan( Map input, { Map? fileIDToCreationTime, @@ -335,7 +429,7 @@ class FaceClusteringService { return clusterFaceIDs; } - static Map _runLinearClustering(Map args) { + static Map runLinearClustering(Map args) { // final input = args['input'] as Map; final input = args['input'] as Set; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 68cae30af8..48feb1ea34 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -234,33 +234,11 @@ class ClusterFeedbackService { maxClusterID++; } } else { - // final clusteringInput = embeddings - // .map((key, value) { - // return MapEntry( - // key, - // FaceInfoForClustering( - // faceID: key, - // embeddingBytes: value, - // faceScore: kMinHighQualityFaceScore + 0.01, - // blurValue: kLapacianDefault, - // ), - // ); - // }) - // .values - // .toSet(); - // final faceIdToCluster = - // await FaceClusteringService.instance.predictLinear( - // clusteringInput, - // fileIDToCreationTime: fileIDToCreationTime, - // distanceThreshold: 0.23, - // useDynamicThreshold: false, - // ); final faceIdToCluster = - await FaceClusteringService.instance.predictCompleteComputer( + await FaceClusteringService.instance.predictWithinClusterComputer( embeddings, fileIDToCreationTime: fileIDToCreationTime, - distanceThreshold: 0.30, - mergeThreshold: 0.30, + distanceThreshold: 0.22, ); if (faceIdToCluster == null || faceIdToCluster.isEmpty) { From 79e763bf162b68438b72ef819e4512858eab728a Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 15:29:02 +0530 Subject: [PATCH 19/39] [mob] Small changes --- .../face_ml/face_clustering/face_clustering_service.dart | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index f8c180b4f7..7d822f9e77 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -732,12 +732,12 @@ class FaceClusteringService { final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 0; i < totalFaces; i++) { + if ((i + 1) % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces"); + } if (faceInfos[i].clusterId != null) continue; int closestIdx = -1; double closestDistance = double.infinity; - if (i + 1 % 250 == 0) { - log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces"); - } for (int j = 0; j < totalFaces; j++) { if (i == j) continue; final double distance = From be1b4b359c26814fd88c9eb0a5c9458ecfb451c0 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 15:32:59 +0530 Subject: [PATCH 20/39] [mob] Fix typo --- mobile/lib/face/db.dart | 38 +++++++++++++++++----------------- mobile/lib/face/db_fields.dart | 12 +++++------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 6b48ef7cac..bd95d0f942 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -161,27 +161,27 @@ class FaceMLDataDB { final db = await instance.database; // find out clusterIds that are assigned to other persons using the clusters table final List> maps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', [personID], ); final Set ignoredClusterIDs = - maps.map((e) => e[cluserIDColumn] as int).toSet(); + maps.map((e) => e[clusterIDColumn] as int).toSet(); final List> rejectMaps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', + 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', [personID], ); final Set rejectClusterIDs = - rejectMaps.map((e) => e[cluserIDColumn] as int).toSet(); + rejectMaps.map((e) => e[clusterIDColumn] as int).toSet(); return ignoredClusterIDs.union(rejectClusterIDs); } Future> getPersonClusterIDs(String personID) async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', [personID], ); - return maps.map((e) => e[cluserIDColumn] as int).toSet(); + return maps.map((e) => e[clusterIDColumn] as int).toSet(); } Future clearTable() async { @@ -250,12 +250,12 @@ class FaceMLDataDB { } final cluterRows = await db.query( clusterPersonTable, - columns: [cluserIDColumn], + columns: [clusterIDColumn], where: '$personIdColumn = ?', whereArgs: [personID], ); final clusterIDs = - cluterRows.map((e) => e[cluserIDColumn] as int).toList(); + cluterRows.map((e) => e[clusterIDColumn] as int).toList(); final List> faceMaps = await db.rawQuery( 'SELECT * FROM $facesTable where ' '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' @@ -641,7 +641,7 @@ class FaceMLDataDB { clusterPersonTable, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, ); } @@ -658,7 +658,7 @@ class FaceMLDataDB { clusterPersonTable, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, conflictAlgorithm: ConflictAlgorithm.replace, ); @@ -675,7 +675,7 @@ class FaceMLDataDB { notPersonFeedback, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, ); } @@ -687,7 +687,7 @@ class FaceMLDataDB { final db = await instance.database; return db.delete( clusterPersonTable, - where: '$personIdColumn = ? AND $cluserIDColumn = ?', + where: '$personIdColumn = ? AND $clusterIDColumn = ?', whereArgs: [personID, clusterID], ); } @@ -699,13 +699,13 @@ class FaceMLDataDB { final List> maps = await db.rawQuery( 'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable ' 'INNER JOIN $clusterPersonTable ' - 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); final Map> result = {}; for (final map in maps) { - final clusterID = map[cluserIDColumn] as int; + final clusterID = map[clusterIDColumn] as int; final String faceID = map[fcFaceId] as String; final fileID = int.parse(faceID.split('_').first); result[fileID] = (result[fileID] ?? {})..add(clusterID); @@ -750,7 +750,7 @@ class FaceMLDataDB { batch.insert( clusterSummaryTable, { - cluserIDColumn: cluserID, + clusterIDColumn: cluserID, avgColumn: avg, countColumn: count, }, @@ -767,7 +767,7 @@ class FaceMLDataDB { final Map result = {}; final rows = await db.rawQuery('SELECT * from $clusterSummaryTable'); for (final r in rows) { - final id = r[cluserIDColumn] as int; + final id = r[clusterIDColumn] as int; final avg = r[avgColumn] as Uint8List; final count = r[countColumn] as int; result[id] = (avg, count); @@ -778,11 +778,11 @@ class FaceMLDataDB { Future> getClusterIDToPersonID() async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $personIdColumn, $cluserIDColumn FROM $clusterPersonTable', + 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable', ); final Map result = {}; for (final map in maps) { - result[map[cluserIDColumn] as int] = map[personIdColumn] as String; + result[map[clusterIDColumn] as int] = map[personIdColumn] as String; } return result; } @@ -827,7 +827,7 @@ class FaceMLDataDB { final db = await instance.database; final faceIdsResult = await db.rawQuery( 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' - 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 2dc98ac1f3..3890a261e6 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -62,13 +62,13 @@ const deletePersonTable = 'DROP TABLE IF EXISTS $personTable'; // Clusters Table Fields & Schema Queries const clusterPersonTable = 'cluster_person'; const personIdColumn = 'person_id'; -const cluserIDColumn = 'cluster_id'; +const clusterIDColumn = 'cluster_id'; const createClusterPersonTable = ''' CREATE TABLE IF NOT EXISTS $clusterPersonTable ( $personIdColumn TEXT NOT NULL, - $cluserIDColumn INTEGER NOT NULL, - PRIMARY KEY($personIdColumn, $cluserIDColumn) + $clusterIDColumn INTEGER NOT NULL, + PRIMARY KEY($personIdColumn, $clusterIDColumn) ); '''; const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable'; @@ -80,10 +80,10 @@ const avgColumn = 'avg'; const countColumn = 'count'; const createClusterSummaryTable = ''' CREATE TABLE IF NOT EXISTS $clusterSummaryTable ( - $cluserIDColumn INTEGER NOT NULL, + $clusterIDColumn INTEGER NOT NULL, $avgColumn BLOB NOT NULL, $countColumn INTEGER NOT NULL, - PRIMARY KEY($cluserIDColumn) + PRIMARY KEY($clusterIDColumn) ); '''; @@ -97,7 +97,7 @@ const notPersonFeedback = 'not_person_feedback'; const createNotPersonFeedbackTable = ''' CREATE TABLE IF NOT EXISTS $notPersonFeedback ( $personIdColumn TEXT NOT NULL, - $cluserIDColumn INTEGER NOT NULL + $clusterIDColumn INTEGER NOT NULL ); '''; const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback'; From 74247c4563bd135802820daca65ea27e5449eb66 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 16:00:47 +0530 Subject: [PATCH 21/39] [mob] Use async sqlite for getting faceID from cluster --- mobile/lib/face/db.dart | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index bd95d0f942..bcff08762e 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -335,12 +335,11 @@ class FaceMLDataDB { } Future> getFaceIDsForCluster(int clusterID) async { - final db = await instance.database; - final List> maps = await db.query( - faceClustersTable, - columns: [fcFaceId], - where: '$fcClusterID = ?', - whereArgs: [clusterID], + final db = await instance.sqliteAsyncDB; + final List> maps = await db.getAll( + 'SELECT $fcFaceId FROM $faceClustersTable ' + 'WHERE $faceClustersTable.$fcClusterID = ?', + [clusterID], ); return maps.map((e) => e[fcFaceId] as String).toSet(); } From e7992674d55178ed7e552539b82504a56d0ad4c4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 16:17:04 +0530 Subject: [PATCH 22/39] [mob] Use async sqlite for getting file creation times --- mobile/lib/db/files_db.dart | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mobile/lib/db/files_db.dart b/mobile/lib/db/files_db.dart index d4a0aae65c..02320d005a 100644 --- a/mobile/lib/db/files_db.dart +++ b/mobile/lib/db/files_db.dart @@ -1325,8 +1325,8 @@ class FilesDB { } Future> getFileIDToCreationTime() async { - final db = await instance.database; - final rows = await db.rawQuery( + final db = await instance.sqliteAsyncDB; + final rows = await db.getAll( ''' SELECT $columnUploadedFileID, $columnCreationTime FROM $filesTable From 34798c344c6e78b79c01e8b72fa7961a61759cb0 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 16:32:11 +0530 Subject: [PATCH 23/39] [mob] Fix typo --- mobile/lib/face/db.dart | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index bcff08762e..27dfa801be 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -410,14 +410,14 @@ class FaceMLDataDB { } Future forceUpdateClusterIds( - Map faceIDToPersonID, + Map faceIDToClusterID, ) async { final db = await instance.database; // Start a batch final batch = db.batch(); - for (final map in faceIDToPersonID.entries) { + for (final map in faceIDToClusterID.entries) { final faceID = map.key; final clusterID = map.value; batch.insert( From e20f13f02bd991bdf5f88978ee8cbc8f9ea35809 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 17:57:14 +0530 Subject: [PATCH 24/39] [mob] Re-cluster when removing file from person/cluster --- mobile/lib/face/db.dart | 69 +++++++++++++++ .../face_ml/feedback/cluster_feedback.dart | 87 +++++++++++++++++-- 2 files changed, 148 insertions(+), 8 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 27dfa801be..6c4c275631 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -344,6 +344,17 @@ class FaceMLDataDB { return maps.map((e) => e[fcFaceId] as String).toSet(); } + Future> getFaceIDsForPerson(String personID) async { + final db = await instance.sqliteAsyncDB; + final faceIdsResult = await db.getAll( + 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' + 'WHERE $clusterPersonTable.$personIdColumn = ?', + [personID], + ); + return faceIdsResult.map((e) => e[fcFaceId] as String).toSet(); + } + Future> getBlurValuesForCluster(int clusterID) async { final db = await instance.sqliteAsyncDB; const String query = ''' @@ -588,6 +599,44 @@ class FaceMLDataDB { return result; } + Future> getFaceEmbeddingMapForFaces( + Iterable faceIDs, + ) async { + _logger.info('reading face embeddings for ${faceIDs.length} faces'); + final db = await instance.sqliteAsyncDB; + + // Define the batch size + const batchSize = 10000; + int offset = 0; + + final Map result = {}; + while (true) { + // Query a batch of rows + final String query = ''' + SELECT $faceIDColumn, $faceEmbeddingBlob + FROM $facesTable + WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")}) + ORDER BY $faceIDColumn DESC + LIMIT $batchSize OFFSET $offset + '''; + final List> maps = await db.getAll(query); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + result[faceID] = map[faceEmbeddingBlob] as Uint8List; + } + if (result.length > 10000) { + break; + } + offset += batchSize; + } + _logger.info('done reading face embeddings for ${faceIDs.length} faces'); + return result; + } + Future getTotalFaceCount({ double minFaceScore = kMinHighQualityFaceScore, }) async { @@ -679,6 +728,26 @@ class FaceMLDataDB { ); } + Future bulkCaptureNotPersonFeedback( + Map clusterToPersonID, + ) async { + final db = await instance.database; + final batch = db.batch(); + for (final entry in clusterToPersonID.entries) { + final clusterID = entry.key; + final personID = entry.value; + batch.insert( + notPersonFeedback, + { + personIdColumn: personID, + clusterIDColumn: clusterID, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + } + await batch.commit(noResult: true); + } + Future removeClusterToPerson({ required String personID, required int clusterID, diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 48feb1ea34..d5049b3654 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -117,17 +117,90 @@ class ClusterFeedbackService { List files, PersonEntity p, ) async { - await FaceMLDataDB.instance.removeFilesFromPerson(files, p.remoteID); - Bus.instance.fire(PeopleChangedEvent()); + try { + // Get the relevant faces to be removed + final faceIDs = await FaceMLDataDB.instance + .getFaceIDsForPerson(p.remoteID) + .then((iterable) => iterable.toList()); + faceIDs.retainWhere((faceID) { + final fileID = getFileIdFromFaceId(faceID); + return files.any((file) => file.uploadedFileID == fileID); + }); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + // Re-cluster within the deleted faces + final newFaceIdToClusterID = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.20, + ); + if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) { + return; + } + + // Update the deleted faces + await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); + + // Make sure the deleted faces don't get suggested in the future + final notClusterIdToPersonId = {}; + for (final clusterId in newFaceIdToClusterID.values.toSet()) { + notClusterIdToPersonId[clusterId] = p.remoteID; + } + await FaceMLDataDB.instance + .bulkCaptureNotPersonFeedback(notClusterIdToPersonId); + + Bus.instance.fire(PeopleChangedEvent()); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromPerson", e, s); + rethrow; + } } Future removeFilesFromCluster( List files, int clusterID, ) async { - await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID); - Bus.instance.fire(PeopleChangedEvent()); - return; + try { + // Get the relevant faces to be removed + final faceIDs = await FaceMLDataDB.instance + .getFaceIDsForCluster(clusterID) + .then((iterable) => iterable.toList()); + faceIDs.retainWhere((faceID) { + final fileID = getFileIdFromFaceId(faceID); + return files.any((file) => file.uploadedFileID == fileID); + }); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + // Re-cluster within the deleted faces + final newFaceIdToClusterID = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.20, + ); + if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) { + return; + } + + // Update the deleted faces + await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); + + Bus.instance.fire(PeopleChangedEvent()); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromCluster", e, s); + rethrow; + } } Future addFilesToCluster(List faceIDs, int clusterID) async { @@ -205,10 +278,8 @@ class ClusterFeedbackService { final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); final originalFaceIDsSet = faceIDs.toSet(); - final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); - final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); - embeddings.removeWhere((key, value) => !faceIDs.contains(key)); + final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs); final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); From ab0a99cf76d1b741e84580c1b47bfd2460325f65 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 11:42:39 +0530 Subject: [PATCH 25/39] [mob] Make sure cluster page is updated after file removal --- mobile/lib/events/files_updated_event.dart | 1 + .../face_ml/feedback/cluster_feedback.dart | 11 +++++++++-- mobile/lib/ui/viewer/people/cluster_page.dart | 11 ++++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mobile/lib/events/files_updated_event.dart b/mobile/lib/events/files_updated_event.dart index 7d7779d490..2fc67d6460 100644 --- a/mobile/lib/events/files_updated_event.dart +++ b/mobile/lib/events/files_updated_event.dart @@ -27,4 +27,5 @@ enum EventType { unhide, coverChanged, peopleChanged, + peopleClusterChanged, } diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d5049b3654..ee098c46f2 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -5,6 +5,8 @@ import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; +import "package:photos/events/files_updated_event.dart"; +import "package:photos/events/local_photos_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; import "package:photos/face/db.dart"; @@ -13,8 +15,6 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; -// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; -// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -196,6 +196,13 @@ class ClusterFeedbackService { await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); Bus.instance.fire(PeopleChangedEvent()); + Bus.instance.fire( + LocalPhotosUpdatedEvent( + files, + type: EventType.peopleClusterChanged, + source: "$clusterID", + ), + ); return; } catch (e, s) { _logger.severe("Error in removeFilesFromCluster", e, s); diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index 7c3a7c3370..47bfef535b 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -69,6 +69,13 @@ class _ClusterPageState extends State { } setState(() {}); } + if (event.type == EventType.peopleClusterChanged && + (event.source == widget.clusterID.toString())) { + for (var updatedFile in event.updatedFiles) { + files.remove(updatedFile); + } + setState(() {}); + } }); kDebugMode ? ClusterFeedbackService.instance.debugLogClusterBlurValues( @@ -107,6 +114,7 @@ class _ClusterPageState extends State { EventType.deletedFromRemote, EventType.deletedFromEverywhere, EventType.hide, + EventType.peopleClusterChanged, }, tagPrefix: widget.tagPrefix + widget.tagPrefix, selectedFiles: _selectedFiles, @@ -118,9 +126,10 @@ class _ClusterPageState extends State { preferredSize: const Size.fromHeight(50.0), child: ClusterAppBar( SearchResultPage.appBarType, - "${widget.searchResult.length} memories${widget.appendTitle}", + "${files.length} memories${widget.appendTitle}", _selectedFiles, widget.clusterID, + key: ValueKey(files.length), ), ), body: Column( From a9ca8a4a2445aeb1b22a8422a28efcf00bd8e525 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 12:19:10 +0530 Subject: [PATCH 26/39] [mob] Better handling of cluster update --- mobile/lib/events/people_changed_event.dart | 21 ++++++++++++++++++- .../face_ml/feedback/cluster_feedback.dart | 18 ++++++++++------ mobile/lib/ui/viewer/people/cluster_page.dart | 10 +++++++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/mobile/lib/events/people_changed_event.dart b/mobile/lib/events/people_changed_event.dart index e2d1358664..51f4eaeefe 100644 --- a/mobile/lib/events/people_changed_event.dart +++ b/mobile/lib/events/people_changed_event.dart @@ -1,3 +1,22 @@ import "package:photos/events/event.dart"; +import "package:photos/models/file/file.dart"; -class PeopleChangedEvent extends Event {} +class PeopleChangedEvent extends Event { + final List? relevantFiles; + final PeopleEventType type; + final String source; + + PeopleChangedEvent({ + this.relevantFiles, + this.type = PeopleEventType.defaultType, + this.source = "", + }); + + @override + String get reason => '$runtimeType{type: ${type.name}, "via": $source}'; +} + +enum PeopleEventType { + defaultType, + removedFilesFromCluster, +} \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index ee098c46f2..4cb9dce6c4 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -5,8 +5,8 @@ import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; -import "package:photos/events/files_updated_event.dart"; -import "package:photos/events/local_photos_updated_event.dart"; +// import "package:photos/events/files_updated_event.dart"; +// import "package:photos/events/local_photos_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; import "package:photos/face/db.dart"; @@ -195,14 +195,20 @@ class ClusterFeedbackService { // Update the deleted faces await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); - Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire( - LocalPhotosUpdatedEvent( - files, - type: EventType.peopleClusterChanged, + PeopleChangedEvent( + relevantFiles: files, + type: PeopleEventType.removedFilesFromCluster, source: "$clusterID", ), ); + // Bus.instance.fire( + // LocalPhotosUpdatedEvent( + // files, + // type: EventType.peopleClusterChanged, + // source: "$clusterID", + // ), + // ); return; } catch (e, s) { _logger.severe("Error in removeFilesFromCluster", e, s); diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index 47bfef535b..12b830932b 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -6,6 +6,7 @@ import "package:flutter_animate/flutter_animate.dart"; import 'package:photos/core/event_bus.dart'; import 'package:photos/events/files_updated_event.dart'; import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/events/people_changed_event.dart"; import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/file/file.dart'; @@ -52,6 +53,7 @@ class _ClusterPageState extends State { final _selectedFiles = SelectedFiles(); late final List files; late final StreamSubscription _filesUpdatedEvent; + late final StreamSubscription _peopleChangedEvent; @override void initState() { @@ -69,9 +71,11 @@ class _ClusterPageState extends State { } setState(() {}); } - if (event.type == EventType.peopleClusterChanged && + }); + _peopleChangedEvent = Bus.instance.on().listen((event) { + if (event.type == PeopleEventType.removedFilesFromCluster && (event.source == widget.clusterID.toString())) { - for (var updatedFile in event.updatedFiles) { + for (var updatedFile in event.relevantFiles!) { files.remove(updatedFile); } setState(() {}); @@ -88,6 +92,7 @@ class _ClusterPageState extends State { @override void dispose() { _filesUpdatedEvent.cancel(); + _peopleChangedEvent.cancel(); super.dispose(); } @@ -110,6 +115,7 @@ class _ClusterPageState extends State { ); }, reloadEvent: Bus.instance.on(), + forceReloadEvents: [Bus.instance.on()], removalEventTypes: const { EventType.deletedFromRemote, EventType.deletedFromEverywhere, From 2b88daa15fd73c11c356bb9300254b3f5e15a9cd Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 14:00:15 +0530 Subject: [PATCH 27/39] [mob] Method for detecting sideways faces --- mobile/lib/face/model/detection.dart | 25 +++++++++++++++++++ .../ui/viewer/file_details/face_widget.dart | 6 +++++ 2 files changed, 31 insertions(+) diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart index 0718340b27..af6324ac85 100644 --- a/mobile/lib/face/model/detection.dart +++ b/mobile/lib/face/model/detection.dart @@ -129,4 +129,29 @@ class Detection { return FaceDirection.straight; } + + bool faceIsSideways() { + final leftEye = [landmarks[0].x, landmarks[0].y]; + final rightEye = [landmarks[1].x, landmarks[1].y]; + final nose = [landmarks[2].x, landmarks[2].y]; + final leftMouth = [landmarks[3].x, landmarks[3].y]; + final rightMouth = [landmarks[4].x, landmarks[4].y]; + + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = + (nose[0] < min(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight); + } } diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index c33004eff4..980049134d 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -175,6 +175,12 @@ class _FaceWidgetState extends State { style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), + if (kDebugMode) + Text( + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), if (kDebugMode) Text( 'V: ${widget.face.visibility}', From ecc1bc99802651dbb737e3568aeb78613e6ec9dd Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 14:58:52 +0530 Subject: [PATCH 28/39] [mob] Use more conservative cluster threshold for sideways faces --- mobile/lib/face/db.dart | 31 +++++++++++++++++++ mobile/lib/face/model/detection.dart | 8 ++++- .../face_clustering_service.dart | 5 +-- .../face_info_for_clustering.dart | 5 +-- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 6c4c275631..add547376f 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -1,4 +1,5 @@ import 'dart:async'; +import "dart:convert" show json; import "dart:io" show Directory; import "dart:math"; @@ -10,6 +11,7 @@ import 'package:path_provider/path_provider.dart'; import "package:photos/extensions/stop_watch.dart"; import 'package:photos/face/db_fields.dart'; import "package:photos/face/db_model_mappers.dart"; +import "package:photos/face/model/detection.dart"; import "package:photos/face/model/face.dart"; import "package:photos/models/file/file.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; @@ -455,6 +457,31 @@ class FaceMLDataDB { ); } + // TODO: remove this method and replace by correct logic during indexing + Future> getFaceIdsToIsSidewaysFaceTEMP() async { + final db = await instance.sqliteAsyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $faceDetectionColumn FROM $facesTable', + ); + + final time = DateTime.now(); + final Map result = {}; + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + final detection = + Detection.fromJson(json.decode(map[faceDetectionColumn] as String)); + result[faceID] = detection.faceIsSideways(); + } + _logger.info( + 'decoding detections and calculating face sideways bools took ${DateTime.now().difference(time).inMilliseconds} ms', + ); + final double sidewaysRatio = + result.values.where((e) => e).length / result.length; + _logger.info('sideways face ratio: $sidewaysRatio'); + + return result; + } + Future> getFaceInfoForClustering({ double minScore = kMinHighQualityFaceScore, int minClarity = kLaplacianHardThreshold, @@ -468,6 +495,9 @@ class FaceMLDataDB { ); final db = await instance.sqliteAsyncDB; + final Map faceIdsToIsSideways = + await getFaceIdsToIsSidewaysFaceTEMP(); + final Set result = {}; while (true) { // Query a batch of rows @@ -494,6 +524,7 @@ class FaceMLDataDB { embeddingBytes: map[faceEmbeddingBlob] as Uint8List, faceScore: map[faceScore] as double, blurValue: map[faceBlur] as double, + isSideways: faceIdsToIsSideways[faceID] ?? false, ); result.add(faceInfo); } diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart index af6324ac85..49e8c3652b 100644 --- a/mobile/lib/face/model/detection.dart +++ b/mobile/lib/face/model/detection.dart @@ -22,7 +22,7 @@ class Detection { bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty; - // emoty box + // empty box Detection.empty() : box = FaceBox( xMin: 0, @@ -94,6 +94,9 @@ class Detection { } FaceDirection getFaceDirection() { + if (isEmpty) { + return FaceDirection.straight; + } final leftEye = [landmarks[0].x, landmarks[0].y]; final rightEye = [landmarks[1].x, landmarks[1].y]; final nose = [landmarks[2].x, landmarks[2].y]; @@ -131,6 +134,9 @@ class Detection { } bool faceIsSideways() { + if (isEmpty) { + return false; + } final leftEye = [landmarks[0].x, landmarks[0].y]; final rightEye = [landmarks[1].x, landmarks[1].y]; final nose = [landmarks[2].x, landmarks[2].y]; diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 7d822f9e77..792ec3b4ce 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -452,7 +452,8 @@ class FaceClusteringService { faceScore: face.faceScore, blurValue: face.blurValue, badFace: face.faceScore < kMinHighQualityFaceScore || - face.blurValue < kLaplacianSoftThreshold, + face.blurValue < kLaplacianSoftThreshold || + face.isSideways, vEmbedding: Vector.fromList( EVector.fromBuffer(face.embeddingBytes).values, dtype: DType.float32, @@ -606,7 +607,7 @@ class FaceClusteringService { ); if (useDynamicThreshold) { log( - "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or high blur value", + "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity", ); } diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart index 8dd78ff1f0..273d85da55 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart @@ -1,4 +1,3 @@ - import "dart:typed_data" show Uint8List; class FaceInfoForClustering { @@ -7,6 +6,7 @@ class FaceInfoForClustering { final Uint8List embeddingBytes; final double faceScore; final double blurValue; + final bool isSideways; FaceInfoForClustering({ required this.faceID, @@ -14,5 +14,6 @@ class FaceInfoForClustering { required this.embeddingBytes, required this.faceScore, required this.blurValue, + this.isSideways = false, }); -} \ No newline at end of file +} From f3f85e81ca15baeb7313fcbbed1e8bb310569bf4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 15:08:10 +0530 Subject: [PATCH 29/39] [mob] Rename constant --- mobile/lib/face/db.dart | 14 +++++++------- mobile/lib/face/model/face.dart | 2 +- .../face_clustering/face_clustering_service.dart | 4 ++-- .../face_filtering/face_filtering_constants.dart | 4 ++-- .../machine_learning/face_ml/face_ml_service.dart | 6 ++++-- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index add547376f..3b3bf3b2c6 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -261,7 +261,7 @@ class FaceMLDataDB { final List> faceMaps = await db.rawQuery( 'SELECT * FROM $facesTable where ' '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' - 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinHighQualityFaceScore ORDER BY $faceScore DESC', + 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC', ); if (faceMaps.isNotEmpty) { if (avatarFileId != null) { @@ -483,7 +483,7 @@ class FaceMLDataDB { } Future> getFaceInfoForClustering({ - double minScore = kMinHighQualityFaceScore, + double minScore = kMinimumQualityFaceScore, int minClarity = kLaplacianHardThreshold, int maxFaces = 20000, int offset = 0, @@ -541,7 +541,7 @@ class FaceMLDataDB { /// /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] Future> getFaceEmbeddingMap({ - double minScore = kMinHighQualityFaceScore, + double minScore = kMinimumQualityFaceScore, int minClarity = kLaplacianHardThreshold, int maxFaces = 20000, int offset = 0, @@ -608,7 +608,7 @@ class FaceMLDataDB { facesTable, columns: [faceIDColumn, faceEmbeddingBlob], where: - '$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', + '$faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', limit: batchSize, offset: offset, orderBy: '$faceIDColumn DESC', @@ -669,7 +669,7 @@ class FaceMLDataDB { } Future getTotalFaceCount({ - double minFaceScore = kMinHighQualityFaceScore, + double minFaceScore = kMinimumQualityFaceScore, }) async { final db = await instance.sqliteAsyncDB; final List> maps = await db.getAll( @@ -682,7 +682,7 @@ class FaceMLDataDB { final db = await instance.sqliteAsyncDB; final List> totalFacesMaps = await db.getAll( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', ); final int totalFaces = totalFacesMaps.first['count'] as int; @@ -699,7 +699,7 @@ class FaceMLDataDB { ]) async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinimumQualityFaceScore', ); return maps.first['count'] as int; } diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/face/model/face.dart index c480bde794..fc4bb57d5f 100644 --- a/mobile/lib/face/model/face.dart +++ b/mobile/lib/face/model/face.dart @@ -22,7 +22,7 @@ class Face { bool get isBlurry => blur < kLaplacianHardThreshold; - bool get hasHighScore => score > kMinHighQualityFaceScore; + bool get hasHighScore => score > kMinimumQualityFaceScore; bool get isHighQuality => (!isBlurry) && hasHighScore; diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 792ec3b4ce..f62c3259a3 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -271,7 +271,7 @@ class FaceClusteringService { FaceInfoForClustering( faceID: key, embeddingBytes: value, - faceScore: kMinHighQualityFaceScore + 0.01, + faceScore: kMinimumQualityFaceScore + 0.01, blurValue: kLapacianDefault, ), ); @@ -451,7 +451,7 @@ class FaceClusteringService { faceID: face.faceID, faceScore: face.faceScore, blurValue: face.blurValue, - badFace: face.faceScore < kMinHighQualityFaceScore || + badFace: face.faceScore < kMinimumQualityFaceScore || face.blurValue < kLaplacianSoftThreshold || face.isSideways, vEmbedding: Vector.fromList( diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart index 7a3ef61e3e..ff9e7b5421 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -8,7 +8,7 @@ const kLaplacianSoftThreshold = 100; const kLapacianDefault = 10000.0; /// The minimum score for a face to be considered a high quality face for clustering and person detection -const kMinHighQualityFaceScore = 0.80; +const kMinimumQualityFaceScore = 0.80; -/// The minimum score for a face to be detected, regardless of quality. Use [kMinHighQualityFaceScore] for high quality faces. +/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 556a4135fa..712264ea12 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -207,7 +207,9 @@ class FaceMlService { final time = DateTime.now(); final FaceMlResult result = await FaceMlService.analyzeImageSync(args); - dev.log("`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms"); + dev.log( + "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms", + ); sendPort.send(result.toJsonString()); break; } @@ -289,7 +291,7 @@ class FaceMlService { } Future clusterAllImages({ - double minFaceScore = kMinHighQualityFaceScore, + double minFaceScore = kMinimumQualityFaceScore, bool clusterInBuckets = true, }) async { _logger.info("`clusterAllImages()` called"); From a92081e703020f2b92da13fb338b7698b728ba23 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 15:20:29 +0530 Subject: [PATCH 30/39] [mob] Only analyze clustering in debugMode --- .../face_ml/face_clustering/face_clustering_service.dart | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index f62c3259a3..99cd2227ff 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -2,9 +2,10 @@ import "dart:async"; import "dart:developer"; import "dart:isolate"; import "dart:math" show max; -import "dart:typed_data"; +import "dart:typed_data" show Uint8List; import "package:computer/computer.dart"; +import "package:flutter/foundation.dart" show kDebugMode; import "package:logging/logging.dart"; import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; @@ -618,6 +619,7 @@ class FaceClusteringService { } static void _analyzeClusterResults(List sortedFaceInfos) { + if (!kDebugMode) return; final stopwatch = Stopwatch()..start(); final Map faceIdToCluster = {}; From 4095b14589b0c9b3b5490e2a3c5bdf01ed988a86 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 15:57:11 +0530 Subject: [PATCH 31/39] [mob] Conservative threshold for combination of low score and blur --- .../face_ml/face_clustering/face_clustering_service.dart | 2 ++ .../face_ml/face_filtering/face_filtering_constants.dart | 3 +++ 2 files changed, 5 insertions(+) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 99cd2227ff..985f3708bb 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -454,6 +454,8 @@ class FaceClusteringService { blurValue: face.blurValue, badFace: face.faceScore < kMinimumQualityFaceScore || face.blurValue < kLaplacianSoftThreshold || + (face.blurValue < kLaplacianVerySoftThreshold && + face.faceScore < kMediumQualityFaceScore) || face.isSideways, vEmbedding: Vector.fromList( EVector.fromBuffer(face.embeddingBytes).values, diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart index ff9e7b5421..b1f2f60183 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -3,12 +3,15 @@ import 'package:photos/services/machine_learning/face_ml/face_detection/face_det /// Blur detection threshold const kLaplacianHardThreshold = 15; const kLaplacianSoftThreshold = 100; +const kLaplacianVerySoftThreshold = 200; /// Default blur value const kLapacianDefault = 10000.0; /// The minimum score for a face to be considered a high quality face for clustering and person detection const kMinimumQualityFaceScore = 0.80; +const kMediumQualityFaceScore = 0.85; +const kHighQualityFaceScore = 0.90; /// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; From 6d3e1325c49b2499313ef5f15d02b559d498a569 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 16:35:33 +0530 Subject: [PATCH 32/39] [mob] Update face widget --- .../ui/viewer/file_details/face_widget.dart | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 980049134d..79637af829 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -48,7 +48,7 @@ class _FaceWidgetState extends State { @override Widget build(BuildContext context) { - if (Platform.isIOS || Platform.isAndroid) { + if (Platform.isIOS) { return FutureBuilder( future: getFaceCrop(), builder: (context, snapshot) { @@ -181,18 +181,6 @@ class _FaceWidgetState extends State { style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), - if (kDebugMode) - Text( - 'V: ${widget.face.visibility}', - style: Theme.of(context).textTheme.bodySmall, - maxLines: 1, - ), - if (kDebugMode) - Text( - 'A: ${widget.face.area()}', - style: Theme.of(context).textTheme.bodySmall, - maxLines: 1, - ), // if (kDebugMode) // if (highlight) // const Text( @@ -316,6 +304,24 @@ class _FaceWidgetState extends State { style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), + if (kDebugMode) + Text( + 'B: ${widget.face.blur.toStringAsFixed(0)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), ], ), ); From a0fa90cb50bc964b3bd54d28387842193a7a8030 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 16:43:41 +0530 Subject: [PATCH 33/39] [mob] Remove faceArea and faceVisibility from local DB --- mobile/lib/face/db.dart | 2 -- mobile/lib/face/db_fields.dart | 4 ---- mobile/lib/face/db_model_mappers.dart | 2 -- 3 files changed, 8 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 3b3bf3b2c6..95aced5a5a 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -311,8 +311,6 @@ class FaceMLDataDB { faceBlur, imageHeight, imageWidth, - faceArea, - faceVisibilityScore, mlVersionColumn, ], where: '$fileIDColumn = ?', diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 3890a261e6..206806e0ca 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -8,8 +8,6 @@ const faceDetectionColumn = 'detection'; const faceEmbeddingBlob = 'eBlob'; const faceScore = 'score'; const faceBlur = 'blur'; -const faceArea = 'area'; -const faceVisibilityScore = 'visibility'; const imageWidth = 'width'; const imageHeight = 'height'; const faceClusterId = 'cluster_id'; @@ -24,8 +22,6 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, $imageHeight INTEGER NOT NULL DEFAULT 0, $imageWidth INTEGER NOT NULL DEFAULT 0, - $faceArea INTEGER NOT NULL DEFAULT 0, - $faceVisibilityScore INTEGER NOT NULL DEFAULT -1, $mlVersionColumn INTEGER NOT NULL DEFAULT -1, PRIMARY KEY($fileIDColumn, $faceIDColumn) ); diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart index 4e33a0bfdb..683281e93d 100644 --- a/mobile/lib/face/db_model_mappers.dart +++ b/mobile/lib/face/db_model_mappers.dart @@ -35,8 +35,6 @@ Map mapRemoteToFaceDB(Face face) { faceScore: face.score, faceBlur: face.blur, mlVersionColumn: faceMlVersion, - faceArea: face.area(), - faceVisibilityScore: face.visibility, imageWidth: face.fileInfo?.imageWidth ?? 0, imageHeight: face.fileInfo?.imageHeight ?? 0, }; From 01aecb9742916d92ad132296ce12540965828eb3 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 19 Apr 2024 18:13:35 +0530 Subject: [PATCH 34/39] [mob] store sideways face boolean in local face table --- mobile/lib/face/db.dart | 34 ++------------------------- mobile/lib/face/db_fields.dart | 2 ++ mobile/lib/face/db_model_mappers.dart | 1 + 3 files changed, 5 insertions(+), 32 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 95aced5a5a..e45ea42523 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -1,5 +1,4 @@ import 'dart:async'; -import "dart:convert" show json; import "dart:io" show Directory; import "dart:math"; @@ -11,7 +10,6 @@ import 'package:path_provider/path_provider.dart'; import "package:photos/extensions/stop_watch.dart"; import 'package:photos/face/db_fields.dart'; import "package:photos/face/db_model_mappers.dart"; -import "package:photos/face/model/detection.dart"; import "package:photos/face/model/face.dart"; import "package:photos/models/file/file.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; @@ -455,31 +453,6 @@ class FaceMLDataDB { ); } - // TODO: remove this method and replace by correct logic during indexing - Future> getFaceIdsToIsSidewaysFaceTEMP() async { - final db = await instance.sqliteAsyncDB; - final List> maps = await db.getAll( - 'SELECT $faceIDColumn, $faceDetectionColumn FROM $facesTable', - ); - - final time = DateTime.now(); - final Map result = {}; - for (final map in maps) { - final faceID = map[faceIDColumn] as String; - final detection = - Detection.fromJson(json.decode(map[faceDetectionColumn] as String)); - result[faceID] = detection.faceIsSideways(); - } - _logger.info( - 'decoding detections and calculating face sideways bools took ${DateTime.now().difference(time).inMilliseconds} ms', - ); - final double sidewaysRatio = - result.values.where((e) => e).length / result.length; - _logger.info('sideways face ratio: $sidewaysRatio'); - - return result; - } - Future> getFaceInfoForClustering({ double minScore = kMinimumQualityFaceScore, int minClarity = kLaplacianHardThreshold, @@ -493,14 +466,11 @@ class FaceMLDataDB { ); final db = await instance.sqliteAsyncDB; - final Map faceIdsToIsSideways = - await getFaceIdsToIsSidewaysFaceTEMP(); - final Set result = {}; while (true) { // Query a batch of rows final List> maps = await db.getAll( - 'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur FROM $facesTable' + 'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable' ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity' ' ORDER BY $faceIDColumn' ' DESC LIMIT $batchSize OFFSET $offset', @@ -522,7 +492,7 @@ class FaceMLDataDB { embeddingBytes: map[faceEmbeddingBlob] as Uint8List, faceScore: map[faceScore] as double, blurValue: map[faceBlur] as double, - isSideways: faceIdsToIsSideways[faceID] ?? false, + isSideways: (map[isSideways] as int) == 1, ); result.add(faceInfo); } diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 206806e0ca..c7d0c703c4 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -8,6 +8,7 @@ const faceDetectionColumn = 'detection'; const faceEmbeddingBlob = 'eBlob'; const faceScore = 'score'; const faceBlur = 'blur'; +const isSideways = 'is_sideways'; const imageWidth = 'width'; const imageHeight = 'height'; const faceClusterId = 'cluster_id'; @@ -20,6 +21,7 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( $faceEmbeddingBlob BLOB NOT NULL, $faceScore REAL NOT NULL, $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, + $isSideways INTEGER NOT NULL DEFAULT 0, $imageHeight INTEGER NOT NULL DEFAULT 0, $imageWidth INTEGER NOT NULL DEFAULT 0, $mlVersionColumn INTEGER NOT NULL DEFAULT -1, diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart index 683281e93d..70dc779157 100644 --- a/mobile/lib/face/db_model_mappers.dart +++ b/mobile/lib/face/db_model_mappers.dart @@ -34,6 +34,7 @@ Map mapRemoteToFaceDB(Face face) { ).writeToBuffer(), faceScore: face.score, faceBlur: face.blur, + isSideways: face.detection.faceIsSideways() ? 1 : 0, mlVersionColumn: faceMlVersion, imageWidth: face.fileInfo?.imageWidth ?? 0, imageHeight: face.fileInfo?.imageHeight ?? 0, From 7be1b63822b886979ccef6da5b6a6a8a6cf15be8 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 20 Apr 2024 12:24:20 +0530 Subject: [PATCH 35/39] [mob] Clean up linear clustering method --- .../face_clustering_service.dart | 34 +++++-------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 985f3708bb..e00335ac20 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -234,7 +234,7 @@ class FaceClusteringService { ); // return _runLinearClusteringInComputer(input); _logger.info( - 'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', + 'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', ); isRunning = false; @@ -493,12 +493,13 @@ class FaceClusteringService { facesWithClusterID.add(faceInfo); } } + final alreadyClusteredCount = facesWithClusterID.length; final sortedFaceInfos = []; sortedFaceInfos.addAll(facesWithClusterID); sortedFaceInfos.addAll(facesWithoutClusterID); log( - "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId", + "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId", ); // Make sure the first face has a clusterId @@ -520,7 +521,6 @@ class FaceClusteringService { sortedFaceInfos[0].clusterId = clusterID; clusterID++; } - final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 1; i < totalFaces; i++) { // Incremental clustering, so we can skip faces that already have a clusterId @@ -563,12 +563,6 @@ class FaceClusteringService { } closestDistance = distance; closestIdx = j; - // if (distance < distanceThreshold) { - // if (sortedFaceInfos[j].faceID.startsWith("14914702") || - // sortedFaceInfos[j].faceID.startsWith("15488756")) { - // log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance'); - // } - // } } } @@ -580,30 +574,20 @@ class FaceClusteringService { ); clusterID++; sortedFaceInfos[closestIdx].clusterId = clusterID; - newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID; } - // if (sortedFaceInfos[i].faceID.startsWith("14914702") || - // sortedFaceInfos[i].faceID.startsWith("15488756")) { - // log( - // "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance", - // ); - // } sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; - newFaceIdToCluster[sortedFaceInfos[i].faceID] = - sortedFaceInfos[closestIdx].clusterId!; } else { - // if (sortedFaceInfos[i].faceID.startsWith("14914702") || - // sortedFaceInfos[i].faceID.startsWith("15488756")) { - // log( - // "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}", - // ); - // } clusterID++; sortedFaceInfos[i].clusterId = clusterID; - newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID; } } + // Finally, assign the new clusterId to the faces + final Map newFaceIdToCluster = {}; + for (final faceInfo in sortedFaceInfos.sublist(alreadyClusteredCount)) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + stopwatchClustering.stop(); log( ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', From a0502886b646f98687936c688bf4b808ee92a2d4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 20 Apr 2024 13:26:47 +0530 Subject: [PATCH 36/39] [mob] Minimum size argument for getting all cluster summaries --- mobile/lib/face/db.dart | 10 +++++++--- .../face_ml/feedback/cluster_feedback.dart | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index e45ea42523..5945175470 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -829,10 +829,14 @@ class FaceMLDataDB { } /// Returns a map of clusterID to (avg embedding, count) - Future> clusterSummaryAll() async { - final db = await instance.database; + Future> getAllClusterSummary([ + int? minClusterSize, + ]) async { + final db = await instance.sqliteAsyncDB; final Map result = {}; - final rows = await db.rawQuery('SELECT * from $clusterSummaryTable'); + final rows = await db.getAll( + 'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}', + ); for (final r in rows) { final id = r[clusterIDColumn] as int; final avg = r[avgColumn] as Uint8List; diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 4cb9dce6c4..95497a90d2 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -661,7 +661,7 @@ class ClusterFeedbackService { ); final Map clusterToSummary = - await faceMlDb.clusterSummaryAll(); + await faceMlDb.getAllClusterSummary(); final Map updatesForClusterSummary = {}; final Map> clusterAvg = {}; @@ -852,7 +852,7 @@ class ClusterFeedbackService { // Get the cluster averages for the person's clusters and the suggestions' clusters final Map clusterToSummary = - await faceMlDb.clusterSummaryAll(); + await faceMlDb.getAllClusterSummary(); // Calculate the avg embedding of the person final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); From 4ac295e1e209b5c42b56295e7e0f6623cfebb7e9 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 20 Apr 2024 13:35:02 +0530 Subject: [PATCH 37/39] [mob][photos] Remove ClusterResult old deprecated code --- .../face_ml/face_ml_result.dart | 266 ------------------ 1 file changed, 266 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart index bb03de519a..6530cde132 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -7,278 +7,12 @@ import "package:photos/models/file/file.dart"; import 'package:photos/models/ml/ml_typedefs.dart'; import "package:photos/models/ml/ml_versions.dart"; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; -import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart'; final _logger = Logger('ClusterResult_FaceMlResult'); -// TODO: should I add [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class? -@Deprecated('We are now just storing the cluster results directly in DB') -class ClusterResult { - final int personId; - String? userDefinedName; - bool get hasUserDefinedName => userDefinedName != null; - - String _thumbnailFaceId; - bool thumbnailFaceIdIsUserDefined; - - final List _fileIds; - final List _faceIds; - - final Embedding medoid; - double medoidDistanceThreshold; - - List get uniqueFileIds => _fileIds.toSet().toList(); - List get fileIDsIncludingPotentialDuplicates => _fileIds; - - List get faceIDs => _faceIds; - - String get thumbnailFaceId => _thumbnailFaceId; - - int get thumbnailFileId => getFileIdFromFaceId(_thumbnailFaceId); - - /// Sets the thumbnail faceId to the given faceId. - /// Throws an exception if the faceId is not in the list of faceIds. - set setThumbnailFaceId(String faceId) { - if (!_faceIds.contains(faceId)) { - throw Exception( - "The faceId $faceId is not in the list of faceIds: $faceId", - ); - } - _thumbnailFaceId = faceId; - thumbnailFaceIdIsUserDefined = true; - } - - /// Sets the [userDefinedName] to the given [customName] - set setUserDefinedName(String customName) { - userDefinedName = customName; - } - - int get clusterSize => _fileIds.toSet().length; - - ClusterResult({ - required this.personId, - required String thumbnailFaceId, - required List fileIds, - required List faceIds, - required this.medoid, - required this.medoidDistanceThreshold, - this.userDefinedName, - this.thumbnailFaceIdIsUserDefined = false, - }) : _thumbnailFaceId = thumbnailFaceId, - _faceIds = faceIds, - _fileIds = fileIds; - - void addFileIDsAndFaceIDs(List fileIDs, List faceIDs) { - assert(fileIDs.length == faceIDs.length); - _fileIds.addAll(fileIDs); - _faceIds.addAll(faceIDs); - } - - // TODO: Consider if we should recalculated the medoid and threshold when deleting or adding a file from the cluster - int removeFileId(int fileId) { - assert(_fileIds.length == _faceIds.length); - if (!_fileIds.contains(fileId)) { - throw Exception( - "The fileId $fileId is not in the list of fileIds: $fileId, so it's not in the cluster and cannot be removed.", - ); - } - - int removedCount = 0; - for (var i = 0; i < _fileIds.length; i++) { - if (_fileIds[i] == fileId) { - assert(getFileIdFromFaceId(_faceIds[i]) == fileId); - _fileIds.removeAt(i); - _faceIds.removeAt(i); - debugPrint( - "Removed fileId $fileId from cluster $personId at index ${i + removedCount}}", - ); - i--; // Adjust index due to removal - removedCount++; - } - } - - _ensureClusterSizeIsAboveMinimum(); - - return removedCount; - } - - int addFileID(int fileID) { - assert(_fileIds.length == _faceIds.length); - if (_fileIds.contains(fileID)) { - return 0; - } - - _fileIds.add(fileID); - _faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID)); - - return 1; - } - - void ensureThumbnailFaceIdIsInCluster() { - if (!_faceIds.contains(_thumbnailFaceId)) { - _thumbnailFaceId = _faceIds[0]; - } - } - - void _ensureClusterSizeIsAboveMinimum() { - if (clusterSize < minimumClusterSize) { - throw Exception( - "Cluster size is below minimum cluster size of $minimumClusterSize", - ); - } - } - - Map _toJson() => { - 'personId': personId, - 'thumbnailFaceId': _thumbnailFaceId, - 'fileIds': _fileIds, - 'faceIds': _faceIds, - 'medoid': medoid, - 'medoidDistanceThreshold': medoidDistanceThreshold, - if (userDefinedName != null) 'userDefinedName': userDefinedName, - 'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined, - }; - - String toJsonString() => jsonEncode(_toJson()); - - static ClusterResult _fromJson(Map json) { - return ClusterResult( - personId: json['personId'] ?? -1, - thumbnailFaceId: json['thumbnailFaceId'] ?? '', - fileIds: - (json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [], - faceIds: - (json['faceIds'] as List?)?.map((item) => item as String).toList() ?? - [], - medoid: - (json['medoid'] as List?)?.map((item) => item as double).toList() ?? - [], - medoidDistanceThreshold: json['medoidDistanceThreshold'] ?? 0, - userDefinedName: json['userDefinedName'], - thumbnailFaceIdIsUserDefined: - json['thumbnailFaceIdIsUserDefined'] as bool, - ); - } - - static ClusterResult fromJsonString(String jsonString) { - return _fromJson(jsonDecode(jsonString)); - } -} - -class ClusterResultBuilder { - int personId = -1; - String? userDefinedName; - String thumbnailFaceId = ''; - bool thumbnailFaceIdIsUserDefined = false; - - List fileIds = []; - List faceIds = []; - - List embeddings = []; - Embedding medoid = []; - double medoidDistanceThreshold = 0; - bool medoidAndThresholdCalculated = false; - final int k = 5; - - ClusterResultBuilder.createFromIndices({ - required List clusterIndices, - required List labels, - required List allEmbeddings, - required List allFileIds, - required List allFaceIds, - }) { - final clusteredFileIds = - clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList(); - final clusteredFaceIds = - clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList(); - final clusteredEmbeddings = - clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList(); - personId = labels[clusterIndices[0]]; - fileIds = clusteredFileIds; - faceIds = clusteredFaceIds; - thumbnailFaceId = faceIds[0]; - embeddings = clusteredEmbeddings; - } - - void calculateAndSetMedoidAndThreshold() { - if (embeddings.isEmpty) { - throw Exception("Cannot calculate medoid and threshold for empty list"); - } - - // Calculate the medoid and threshold - final (tempMedoid, distanceThreshold) = - _calculateMedoidAndDistanceTreshold(embeddings); - - // Update the medoid - medoid = List.from(tempMedoid); - - // Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor - medoidDistanceThreshold = distanceThreshold; - - medoidAndThresholdCalculated = true; - } - - (List, double) _calculateMedoidAndDistanceTreshold( - List> embeddings, - ) { - double minDistance = double.infinity; - List? medoid; - - // Calculate the distance between all pairs - for (int i = 0; i < embeddings.length; ++i) { - double totalDistance = 0; - for (int j = 0; j < embeddings.length; ++j) { - if (i != j) { - totalDistance += cosineDistance(embeddings[i], embeddings[j]); - - // Break early if we already exceed minDistance - if (totalDistance > minDistance) { - break; - } - } - } - - // Find the minimum total distance - if (totalDistance < minDistance) { - minDistance = totalDistance; - medoid = embeddings[i]; - } - } - - // Now, calculate k-th nearest neighbor for the medoid - final List distancesToMedoid = []; - for (List embedding in embeddings) { - if (embedding != medoid) { - distancesToMedoid.add(cosineDistance(medoid!, embedding)); - } - } - distancesToMedoid.sort(); - // TODO: empirically find the best k. Probably it should be dynamic in some way, so for instance larger for larger clusters and smaller for smaller clusters, especially since there are a lot of really small clusters and a few really large ones. - final double kthDistance = distancesToMedoid[ - distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1]; - - return (medoid!, kthDistance); - } - - void changeThumbnailFaceId(String faceId) { - if (!faceIds.contains(faceId)) { - throw Exception( - "The faceId $faceId is not in the list of faceIds: $faceIds", - ); - } - thumbnailFaceId = faceId; - } - - void addFileIDsAndFaceIDs(List addedFileIDs, List addedFaceIDs) { - assert(addedFileIDs.length == addedFaceIDs.length); - fileIds.addAll(addedFileIDs); - faceIds.addAll(addedFaceIDs); - } -} - @immutable class FaceMlResult { final int fileId; From cde17441d630b98427c993039ec1737472b3496f Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 20 Apr 2024 13:35:53 +0530 Subject: [PATCH 38/39] [mob][photos] Forgot one import --- .../lib/services/machine_learning/face_ml/face_ml_result.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart index 6530cde132..19f954013e 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -1,6 +1,6 @@ import "dart:convert" show jsonEncode, jsonDecode; -import "package:flutter/material.dart" show debugPrint, immutable; +import "package:flutter/material.dart" show immutable; import "package:logging/logging.dart"; import "package:photos/face/model/dimension.dart"; import "package:photos/models/file/file.dart"; From edf99385dc53054b7131d5109728b09c656249bd Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Sat, 20 Apr 2024 14:38:46 +0530 Subject: [PATCH 39/39] [mob][photos] Precompute cluster summaries incrementally during clustering --- .../face_clustering_service.dart | 95 +++++++++++++++++-- .../face_ml/face_ml_service.dart | 26 +++-- 2 files changed, 104 insertions(+), 17 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index e00335ac20..80fff99c68 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -42,6 +42,15 @@ class FaceInfo { enum ClusterOperation { linearIncrementalClustering, dbscanClustering } +class ClusteringResult { + final Map newFaceIdToCluster; + final Map? newClusterSummaries; + ClusteringResult({ + required this.newFaceIdToCluster, + required this.newClusterSummaries, + }); +} + class FaceClusteringService { final _logger = Logger("FaceLinearClustering"); final _computer = Computer.shared(); @@ -191,13 +200,14 @@ class FaceClusteringService { /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. /// /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic. - Future?> predictLinear( + Future predictLinear( Set input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, double conservativeDistanceThreshold = kConservativeDistanceThreshold, bool useDynamicThreshold = true, int? offset, + required Map oldClusterSummaries, }) async { if (input.isEmpty) { _logger.warning( @@ -219,7 +229,7 @@ class FaceClusteringService { final stopwatchClustering = Stopwatch()..start(); // final Map faceIdToCluster = // await _runLinearClusteringInComputer(input); - final Map faceIdToCluster = await _runInIsolate( + final ClusteringResult? faceIdToCluster = await _runInIsolate( ( ClusterOperation.linearIncrementalClustering, { @@ -229,6 +239,7 @@ class FaceClusteringService { 'conservativeDistanceThreshold': conservativeDistanceThreshold, 'useDynamicThreshold': useDynamicThreshold, 'offset': offset, + 'oldClusterSummaries': oldClusterSummaries, } ), ); @@ -247,7 +258,7 @@ class FaceClusteringService { } /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding - Future?> predictLinearComputer( + Future predictLinearComputer( Map input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, @@ -256,7 +267,7 @@ class FaceClusteringService { _logger.warning( "Linear Clustering dataset of embeddings is empty, returning empty list.", ); - return {}; + return null; } // Clustering inside the isolate @@ -290,7 +301,7 @@ class FaceClusteringService { "useDynamicThreshold": false, }, taskName: "createImageEmbedding", - ) as Map; + ) as ClusteringResult; final endTime = DateTime.now(); _logger.info( "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", @@ -369,11 +380,12 @@ class FaceClusteringService { _logger.info( 'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold', ); - return predictLinearComputer( + final clusterResult = await predictLinearComputer( input, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: distanceThreshold, ); + return clusterResult?.newFaceIdToCluster; } } catch (e, s) { _logger.severe(e, s); @@ -430,7 +442,7 @@ class FaceClusteringService { return clusterFaceIDs; } - static Map runLinearClustering(Map args) { + static ClusteringResult? runLinearClustering(Map args) { // final input = args['input'] as Map; final input = args['input'] as Set; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; @@ -439,6 +451,8 @@ class FaceClusteringService { args['conservativeDistanceThreshold'] as double; final useDynamicThreshold = args['useDynamicThreshold'] as bool; final offset = args['offset'] as int?; + final oldClusterSummaries = + args['oldClusterSummaries'] as Map?; log( "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", @@ -507,7 +521,7 @@ class FaceClusteringService { int dynamicThresholdCount = 0; if (sortedFaceInfos.isEmpty) { - return {}; + return null; } // Start actual clustering @@ -584,7 +598,9 @@ class FaceClusteringService { // Finally, assign the new clusterId to the faces final Map newFaceIdToCluster = {}; - for (final faceInfo in sortedFaceInfos.sublist(alreadyClusteredCount)) { + final newClusteredFaceInfos = + sortedFaceInfos.sublist(alreadyClusteredCount); + for (final faceInfo in newClusteredFaceInfos) { newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; } @@ -598,10 +614,69 @@ class FaceClusteringService { ); } + // Now calculate the mean of the embeddings for each cluster and update the cluster summaries + Map? newClusterSummaries; + if (oldClusterSummaries != null) { + newClusterSummaries = FaceClusteringService.updateClusterSummaries( + oldSummary: oldClusterSummaries, + newFaceInfos: newClusteredFaceInfos, + ); + } + // analyze the results FaceClusteringService._analyzeClusterResults(sortedFaceInfos); - return newFaceIdToCluster; + return ClusteringResult( + newFaceIdToCluster: newFaceIdToCluster, + newClusterSummaries: newClusterSummaries, + ); + } + + static Map updateClusterSummaries({ + required Map oldSummary, + required List newFaceInfos, + }) { + final calcSummariesStart = DateTime.now(); + final Map> newClusterIdToFaceInfos = {}; + for (final faceInfo in newFaceInfos) { + if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) { + newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo); + } else { + newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + + final Map newClusterSummaries = {}; + for (final clusterId in newClusterIdToFaceInfos.keys) { + final List newEmbeddings = newClusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final newCount = newEmbeddings.length; + if (oldSummary.containsKey(clusterId)) { + final oldMean = Vector.fromList( + EVector.fromBuffer(oldSummary[clusterId]!.$1).values, + dtype: DType.float32, + ); + final oldCount = oldSummary[clusterId]!.$2; + final oldEmbeddings = oldMean * oldCount; + newEmbeddings.add(oldEmbeddings); + final newMeanVector = + newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); + newClusterSummaries[clusterId] = ( + EVector(values: newMeanVector.toList()).writeToBuffer(), + oldCount + newCount + ); + } else { + final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount; + newClusterSummaries[clusterId] = + (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount); + } + } + log( + "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms", + ); + + return newClusterSummaries; } static void _analyzeClusterResults(List sortedFaceInfos) { diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 712264ea12..3df9b30561 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -300,6 +300,10 @@ class FaceMlService { // Get a sense of the total number of faces in the database final int totalFaces = await FaceMLDataDB.instance .getTotalFaceCount(minFaceScore: minFaceScore); + + // Get the current cluster statistics + final Map oldClusterSummaries = + await FaceMLDataDB.instance.getAllClusterSummary(); if (clusterInBuckets) { // read the creation times from Files DB, in a map from fileID to creation time final fileIDToCreationTime = @@ -332,18 +336,22 @@ class FaceMlService { break; } - final faceIdToCluster = + final clusteringResult = await FaceClusteringService.instance.predictLinear( faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, offset: offset, + oldClusterSummaries: oldClusterSummaries, ); - if (faceIdToCluster == null) { + if (clusteringResult == null) { _logger.warning("faceIdToCluster is null"); return; } - await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + await FaceMLDataDB.instance + .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusteringResult.newClusterSummaries!); _logger.info( 'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', ); @@ -374,12 +382,13 @@ class FaceMlService { '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms'); // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID - final faceIdToCluster = + final clusteringResult = await FaceClusteringService.instance.predictLinear( faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, + oldClusterSummaries: oldClusterSummaries, ); - if (faceIdToCluster == null) { + if (clusteringResult == null) { _logger.warning("faceIdToCluster is null"); return; } @@ -390,9 +399,12 @@ class FaceMlService { // Store the updated clusterIDs in the database _logger.info( - 'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB', + 'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB', ); - await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + await FaceMLDataDB.instance + .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusteringResult.newClusterSummaries!); _logger.info('Done updating FaceIDs with clusterIDs in the DB, in ' '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds'); }