From ea8846ee0c2d08b95da7d624a23ac555918c35f4 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Mon, 6 May 2024 10:38:26 +0530 Subject: [PATCH] [mob][photos] Properly sort faces on fileCreationTime for clustering --- .../face_info_for_clustering.dart | 2 +- .../face_ml/face_ml_service.dart | 77 +++++++++++-------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart index 273d85da55..bd32889dcb 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart @@ -2,7 +2,7 @@ import "dart:typed_data" show Uint8List; class FaceInfoForClustering { final String faceID; - final int? clusterId; + int? clusterId; final Uint8List embeddingBytes; final double faceScore; final double blurValue; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 7a9b75b1b5..b39a09af1e 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -2,6 +2,7 @@ import "dart:async"; import "dart:developer" as dev show log; import "dart:io" show File; import "dart:isolate"; +import "dart:math" show min; import "dart:typed_data" show Uint8List, Float32List, ByteData; import "dart:ui" show Image; @@ -302,30 +303,41 @@ class FaceMlService { final int totalFaces = await FaceMLDataDB.instance .getTotalFaceCount(minFaceScore: minFaceScore); + // read the creation times from Files DB, in a map from fileID to creation time + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + final startEmbeddingFetch = DateTime.now(); + // read all embeddings + final allFaceInfoForClustering = await FaceMLDataDB.instance + .getFaceInfoForClustering( + minScore: minFaceScore, + maxFaces: totalFaces, + ) + .then((set) => set.toList()); + // sort the embeddings based on file creation time, oldest first + allFaceInfoForClustering.sort((a, b) { + final aFileId = getFileIdFromFaceId(a.faceID); + final bFileId = getFileIdFromFaceId(b.faceID); + return fileIDToCreationTime[aFileId]! + .compareTo(fileIDToCreationTime[bFileId]!); + }); + _logger.info( + 'Getting and sorting embeddings took ${DateTime.now().difference(startEmbeddingFetch).inMilliseconds} ms', + ); + // Get the current cluster statistics final Map oldClusterSummaries = await FaceMLDataDB.instance.getAllClusterSummary(); - if (clusterInBuckets) { - // read the creation times from Files DB, in a map from fileID to creation time - final fileIDToCreationTime = - await FilesDB.instance.getFileIDToCreationTime(); + if (clusterInBuckets) { const int bucketSize = 20000; - const int batchSize = 20000; const int offsetIncrement = 7500; int offset = 0; int bucket = 1; while (true) { - final bucketStartTime = DateTime.now(); - final faceInfoForClustering = - await FaceMLDataDB.instance.getFaceInfoForClustering( - minScore: minFaceScore, - maxFaces: bucketSize, - offset: offset, - batchSize: batchSize, - ); - if (faceInfoForClustering.isEmpty) { + if (offset > allFaceInfoForClustering.length - 1) { _logger.warning( 'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces', ); @@ -338,9 +350,15 @@ class FaceMlService { break; } + final bucketStartTime = DateTime.now(); + final faceInfoForClustering = allFaceInfoForClustering.sublist( + offset, + min(offset + bucketSize, allFaceInfoForClustering.length), + ); + final clusteringResult = await FaceClusteringService.instance.predictLinear( - faceInfoForClustering, + faceInfoForClustering.toSet(), fileIDToCreationTime: fileIDToCreationTime, offset: offset, oldClusterSummaries: oldClusterSummaries, @@ -354,6 +372,14 @@ class FaceMlService { .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster); await FaceMLDataDB.instance .clusterSummaryUpdate(clusteringResult.newClusterSummaries!); + for (final faceInfo in faceInfoForClustering) { + faceInfo.clusterId ??= + clusteringResult.newFaceIdToCluster[faceInfo.faceID]; + } + for (final clusterUpdate + in clusteringResult.newClusterSummaries!.entries) { + oldClusterSummaries[clusterUpdate.key] = clusterUpdate.value; + } _logger.info( 'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset, in ${DateTime.now().difference(bucketStartTime).inSeconds} seconds', ); @@ -365,28 +391,11 @@ class FaceMlService { bucket++; } } else { - // Read all the embeddings from the database, in a map from faceID to embedding final clusterStartTime = DateTime.now(); - final faceInfoForClustering = - await FaceMLDataDB.instance.getFaceInfoForClustering( - minScore: minFaceScore, - maxFaces: totalFaces, - ); - final gotFaceEmbeddingsTime = DateTime.now(); - _logger.info( - 'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', - ); - - // Read the creation times from Files DB, in a map from fileID to creation time - final fileIDToCreationTime = - await FilesDB.instance.getFileIDToCreationTime(); - _logger.info('read creation times from FilesDB in ' - '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms'); - // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID final clusteringResult = await FaceClusteringService.instance.predictLinear( - faceInfoForClustering, + allFaceInfoForClustering.toSet(), fileIDToCreationTime: fileIDToCreationTime, oldClusterSummaries: oldClusterSummaries, ); @@ -396,7 +405,7 @@ class FaceMlService { } final clusterDoneTime = DateTime.now(); _logger.info( - 'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', + 'done with clustering ${allFaceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', ); // Store the updated clusterIDs in the database