## Description

- Fixed issue with slow clustering
- Reset for face thumbnail generation

## Test

Tested in debug mode on my Pixel.
This commit is contained in:
Laurens Priem
2024-05-28 16:42:47 +05:30
committed by GitHub
12 changed files with 74 additions and 206 deletions

View File

@@ -5,17 +5,15 @@ import "package:ml_linalg/linalg.dart";
/// Returns the cosine distance (`1 - dot product`) between two vectors.
///
/// The dot product is delegated to ml_linalg's [Vector.dot].
///
/// WARNING: Both vectors must already be normalized; no normalization is
/// performed here.
/// WARNING: In very hot loops, writing `1 - a.dot(b)` inline may be faster
/// still.
///
/// Throws an [ArgumentError] when the vectors differ in length.
@pragma("vm:prefer-inline")
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
  final sameLength = vector1.length == vector2.length;
  if (!sameLength) {
    throw ArgumentError('Vectors must be the same length');
  }
  final similarity = vector1.dot(vector2);
  return 1 - similarity;
}
/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
///
/// WARNING: Only use when you're not sure if vectors are normalized. If you're sure they are, use [cosineDistanceSIMD] instead for better performance.
/// WARNING: Only use when you're not sure if vectors are normalized. If you're sure they are, use [cosineDistanceSIMD] instead for better performance, or inline for best performance.
double cosineDistanceSIMDSafe(Vector vector1, Vector vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');

View File

@@ -10,11 +10,9 @@ import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:simple_cluster/simple_cluster.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
@@ -22,7 +20,6 @@ class FaceInfo {
final double? faceScore;
final double? blurValue;
final bool? badFace;
final List<double>? embedding;
final Vector? vEmbedding;
int? clusterId;
String? closestFaceId;
@@ -33,14 +30,13 @@ class FaceInfo {
this.faceScore,
this.blurValue,
this.badFace,
this.embedding,
this.vEmbedding,
this.clusterId,
this.fileCreationTime,
});
}
enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
enum ClusterOperation { linearIncrementalClustering }
class ClusteringResult {
final Map<String, int> newFaceIdToCluster;
@@ -129,10 +125,6 @@ class FaceClusteringService {
final result = FaceClusteringService.runLinearClustering(args);
sendPort.send(result);
break;
case ClusterOperation.dbscanClustering:
final result = FaceClusteringService._runDbscanClustering(args);
sendPort.send(result);
break;
}
} catch (e, stackTrace) {
sendPort
@@ -203,8 +195,6 @@ class FaceClusteringService {
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
///
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
Future<ClusteringResult?> predictLinear(
Set<FaceInfoForClustering> input, {
Map<int, int>? fileIDToCreationTime,
@@ -401,55 +391,6 @@ class FaceClusteringService {
}
}
/// Runs DBSCAN clustering on [input] inside the clustering isolate.
///
/// [input] maps face IDs to their serialized embedding buffers.
/// [fileIDToCreationTime], [eps] and [minPts] are forwarded unchanged to the
/// isolate.
///
/// Returns one list of face IDs per cluster. Returns an empty list when
/// [input] is empty or when a clustering run is already in progress.
Future<List<List<String>>> predictDbscan(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double eps = 0.3,
  int minPts = 5,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "DBSCAN Clustering dataset of embeddings is empty, returning empty list.",
    );
    return [];
  }
  if (isRunning) {
    _logger.warning(
      "DBSCAN Clustering is already running, returning empty list.",
    );
    return [];
  }

  isRunning = true;
  try {
    // Clustering inside the isolate
    _logger.info(
      "Start DBSCAN clustering on ${input.length} embeddings inside computer isolate",
    );
    final stopwatchClustering = Stopwatch()..start();
    final List<List<String>> clusterFaceIDs = await _runInIsolate(
      (
        ClusterOperation.dbscanClustering,
        {
          'input': input,
          'fileIDToCreationTime': fileIDToCreationTime,
          'eps': eps,
          'minPts': minPts,
        }
      ),
    );
    _logger.info(
      'DBSCAN Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
    );
    return clusterFaceIDs;
  } finally {
    // Reset the flag even when the isolate call throws; otherwise every later
    // call would see isRunning == true and return an empty list.
    isRunning = false;
  }
}
static ClusteringResult? runLinearClustering(Map args) {
// final input = args['input'] as Map<String, (int?, Uint8List)>;
final input = args['input'] as Set<FaceInfoForClustering>;
@@ -562,19 +503,10 @@ class FaceClusteringService {
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
}
// WARNING: The loop below is now O(n^2) so be very careful with anything you put in there!
for (int j = i - 1; j >= 0; j--) {
late double distance;
if (sortedFaceInfos[i].vEmbedding != null) {
distance = cosineDistanceSIMD(
sortedFaceInfos[i].vEmbedding!,
sortedFaceInfos[j].vEmbedding!,
);
} else {
distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!,
sortedFaceInfos[j].embedding!,
);
}
final double distance = 1 -
sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
if (distance < closestDistance) {
if (sortedFaceInfos[j].badFace! &&
distance > conservativeDistanceThreshold) {
@@ -814,10 +746,8 @@ class FaceClusteringService {
double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
final double distance = cosineDistanceSIMD(
faceInfos[i].vEmbedding!,
faceInfos[j].vEmbedding!,
);
final double distance =
1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
@@ -870,10 +800,10 @@ class FaceClusteringService {
for (int i = 0; i < clusterIds.length; i++) {
for (int j = 0; j < clusterIds.length; j++) {
if (i == j) continue;
final double newDistance = cosineDistanceSIMD(
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
);
final double newDistance = 1 -
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!
.$1
.dot(clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1);
if (newDistance < distance) {
distance = newDistance;
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
@@ -944,55 +874,6 @@ class FaceClusteringService {
newClusterIdToFaceIds: clusterIdToFaceIds,
);
}
/// Clusters the faces passed in [args] with DBSCAN and returns the face IDs
/// belonging to each cluster.
///
/// Reads these keys from [args]: `'input'` (face ID to embedding buffer),
/// `'fileIDToCreationTime'` (optional), `'eps'` and `'minPts'`. When creation
/// times are supplied, faces are sorted with [_sortFaceInfosOnCreationTime]
/// before clustering.
static List<List<String>> _runDbscanClustering(Map args) {
  final input = args['input'] as Map<String, Uint8List>;
  final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
  final eps = args['eps'] as double;
  final minPts = args['minPts'] as int;

  log(
    "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
  );

  final clusterer = DBSCAN(
    epsilon: eps,
    minPoints: minPts,
    distanceMeasure: cosineDistForNormVectors,
  );

  // One FaceInfo per input entry, with the embedding decoded from its buffer.
  final List<FaceInfo> faces = [
    for (final entry in input.entries)
      FaceInfo(
        faceID: entry.key,
        embedding: EVector.fromBuffer(entry.value).values,
        fileCreationTime:
            fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
      ),
  ];
  if (fileIDToCreationTime != null) {
    _sortFaceInfosOnCreationTime(faces);
  }

  // DBSCAN works on raw embeddings and returns clusters as index lists into
  // that same ordering.
  final List<List<double>> vectors = [
    for (final face in faces) face.embedding!,
  ];
  final List<List<int>> clusters = clusterer.run(vectors);

  // Translate the indices back to face IDs.
  return [
    for (final memberIndices in clusters)
      [for (final idx in memberIndices) faces[idx].faceID],
  ];
}
}
/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first

View File

@@ -651,6 +651,19 @@ class FaceMlService {
min(offset + bucketSize, allFaceInfoForClustering.length),
);
if (faceInfoForClustering.every((face) => face.clusterId != null)) {
_logger.info('Everything in bucket $bucket is already clustered');
if (offset + bucketSize >= totalFaces) {
_logger.info('All faces clustered');
break;
} else {
_logger.info('Skipping to next bucket');
offset += offsetIncrement;
bucket++;
continue;
}
}
final clusteringResult =
await FaceClusteringService.instance.predictLinear(
faceInfoForClustering.toSet(),

View File

@@ -13,7 +13,6 @@ import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@@ -434,7 +433,9 @@ class ClusterFeedbackService {
distanceThreshold: 0.22,
);
if (clusterResult == null || clusterResult.newClusterIdToFaceIds == null || clusterResult.isEmpty) {
if (clusterResult == null ||
clusterResult.newClusterIdToFaceIds == null ||
clusterResult.isEmpty) {
_logger.warning('No clusters found or something went wrong');
return ClusteringResult(newFaceIdToCluster: {});
}
@@ -537,8 +538,7 @@ class ClusterFeedbackService {
EVector.fromBuffer(clusterSummary.$1).values,
dtype: DType.float32,
);
final bigClustersMeanDistance =
cosineDistanceSIMD(biggestMean, currentMean);
final bigClustersMeanDistance = 1 - biggestMean.dot(currentMean);
_logger.info(
"Mean distance between biggest cluster and current cluster: $bigClustersMeanDistance",
);
@@ -595,8 +595,7 @@ class ClusterFeedbackService {
final List<double> trueDistances = [];
for (final biggestEmbedding in biggestSampledEmbeddings) {
for (final currentEmbedding in currentSampledEmbeddings) {
distances
.add(cosineDistanceSIMD(biggestEmbedding, currentEmbedding));
distances.add(1 - biggestEmbedding.dot(currentEmbedding));
trueDistances.add(
biggestEmbedding.distanceTo(
currentEmbedding,
@@ -789,7 +788,7 @@ class ClusterFeedbackService {
final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) {
distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
distances.add(1 - embedding.dot(otherEmbedding));
}
}
distances.sort();
@@ -1086,7 +1085,7 @@ class ClusterFeedbackService {
final fileIdToDistanceMap = {};
for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
cosineDistanceSIMD(personAvg, entry.value);
1 - personAvg.dot(entry.value);
}
w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) {
@@ -1141,7 +1140,7 @@ List<(int, double)> _calcSuggestionsMean(Map<String, dynamic> args) {
continue;
}
final Vector avg = clusterAvg[personCluster]!;
final distance = cosineDistanceSIMD(avg, otherAvg);
final distance = 1 - avg.dot(otherAvg);
comparisons++;
if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) {

View File

@@ -18,7 +18,7 @@ class MachineLearningController {
static const kMaximumTemperature = 42; // 42 degree celsius
static const kMinimumBatteryLevel = 20; // 20%
static const kDefaultInteractionTimeout = Duration(seconds: 10);
static const kDefaultInteractionTimeout = Duration(seconds: 15);
static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"];
bool _isDeviceHealthy = true;

View File

@@ -293,7 +293,7 @@ class _FaceWidgetState extends State<FaceWidget> {
}
}
Future<Map<String, Uint8List>?> getFaceCrop() async {
Future<Map<String, Uint8List>?> getFaceCrop({int fetchAttempt = 1}) async {
try {
final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID);
if (cachedFace != null) {
@@ -326,6 +326,10 @@ class _FaceWidgetState extends State<FaceWidget> {
error: e,
stackTrace: s,
);
resetPool(fullFile: true);
if (fetchAttempt <= retryLimit) {
return getFaceCrop(fetchAttempt: fetchAttempt + 1);
}
return null;
}
}

View File

@@ -173,7 +173,9 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
}
Future<Map<String, Uint8List>?> getRelevantFaceCrops(
Iterable<Face> faces,
Iterable<Face> faces, {
int fetchAttempt = 1,
}
) async {
try {
final faceIdToCrop = <String, Uint8List>{};
@@ -223,6 +225,10 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
error: e,
stackTrace: s,
);
resetPool(fullFile: true);
if(fetchAttempt <= retryLimit) {
return getRelevantFaceCrops(faces, fetchAttempt: fetchAttempt + 1);
}
return null;
}
}

View File

@@ -188,7 +188,16 @@ class _PersonActionSheetState extends State<PersonActionSheet> {
if (snapshot.hasError) {
log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
//Need to show an error on the UI here
return const SizedBox.shrink();
if (kDebugMode) {
return Column(
children: [
Text('${snapshot.error}'),
Text('${snapshot.stackTrace}'),
],
);
} else {
return const SizedBox.shrink();
}
} else if (snapshot.hasData) {
final persons = snapshot.data!;
final searchResults = _searchQuery.isNotEmpty

View File

@@ -3,7 +3,6 @@ import 'dart:async';
import "package:flutter/foundation.dart";
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import "package:ml_linalg/linalg.dart";
import 'package:photos/core/configuration.dart';
import 'package:photos/core/event_bus.dart';
import "package:photos/db/files_db.dart";
@@ -11,12 +10,10 @@ import "package:photos/events/people_changed_event.dart";
import 'package:photos/events/subscription_purchased_event.dart';
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
@@ -160,58 +157,6 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
return actions;
}
/// Debug helper that checks how tightly the faces of this cluster group.
///
/// Loads the embeddings of every face in `widget.clusterID`. For each face it
/// finds the closest other face in the same cluster, and logs a severe
/// message when that closest distance exceeds 0.3. The float32 and float64
/// SIMD distances of the same pair are included in the message.
///
/// [context] is not used by the current implementation.
@Deprecated(
  'Used for debugging an issue with conflicts on cluster IDs, resolved now',
)
Future<void> _validateCluster(BuildContext context) async {
  _logger.info('_validateCluster called');
  final faceMlDb = FaceMLDataDB.instance;

  final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
  final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();

  // The DB returns embeddings for every face in the files; keep only the
  // faces that belong to this cluster.
  final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
  embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
  final embeddings = embeddingsBlobs
      .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));

  // Build the SIMD vectors once per face instead of once per compared pair.
  final vectors64 = <String, Vector>{
    for (final entry in embeddings.entries)
      entry.key: Vector.fromList(entry.value, dtype: DType.float64),
  };
  final vectors32 = <String, Vector>{
    for (final entry in embeddings.entries)
      entry.key: Vector.fromList(entry.value, dtype: DType.float32),
  };

  for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
    final embedding64 = vectors64[embedding.key]!;
    final embedding32 = vectors32[embedding.key]!;

    double closestDistance = double.infinity;
    double closestDistance32 = double.infinity;
    double closestDistance64 = double.infinity;
    String? closestFaceID;
    for (final MapEntry<String, List<double>> otherEmbedding
        in embeddings.entries) {
      if (embedding.key == otherEmbedding.key) {
        continue;
      }
      final distance64 =
          cosineDistanceSIMD(embedding64, vectors64[otherEmbedding.key]!);
      final distance32 =
          cosineDistanceSIMD(embedding32, vectors32[otherEmbedding.key]!);
      // The plain-list distance decides which face counts as closest.
      final distance = cosineDistForNormVectors(
        embedding.value,
        otherEmbedding.value,
      );
      if (distance < closestDistance) {
        closestDistance = distance;
        closestDistance32 = distance32;
        closestDistance64 = distance64;
        closestFaceID = otherEmbedding.key;
      }
    }
    // NOTE(review): for a single-face cluster closestDistance stays infinite,
    // so this still logs with a null closestFaceID, as it did before.
    if (closestDistance > 0.3) {
      _logger.severe(
        "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
      );
    }
  }
}
Future<void> _onIgnoredClusterClicked(BuildContext context) async {
await showChoiceDialog(
context,

View File

@@ -121,7 +121,7 @@ class PersonFaceWidget extends StatelessWidget {
);
}
Future<Uint8List?> getFaceCrop() async {
Future<Uint8List?> getFaceCrop({int fetchAttempt = 1}) async {
try {
final Face? face = await _getFace();
if (face == null) {
@@ -187,6 +187,10 @@ class PersonFaceWidget extends StatelessWidget {
error: e,
stackTrace: s,
);
resetPool(fullFile: useFullFile);
if(fetchAttempt <= retryLimit) {
return getFaceCrop(fetchAttempt: fetchAttempt + 1);
}
return null;
}
}

View File

@@ -11,11 +11,20 @@ import "package:photos/utils/image_ml_isolate.dart";
import "package:photos/utils/thumbnail_util.dart";
import "package:pool/pool.dart";
/// Replaces one of the face-generation pools with a fresh [Pool].
///
/// When [fullFile] is true, [poolFullFileFaceGenerations] is replaced with a
/// pool of 20 resources. Otherwise [poolThumbnailFaceGenerations] is replaced
/// with a pool of 100. Both use a 15 second timeout.
void resetPool({required bool fullFile}) {
  const poolTimeout = Duration(seconds: 15);
  if (!fullFile) {
    poolThumbnailFaceGenerations = Pool(100, timeout: poolTimeout);
    return;
  }
  poolFullFileFaceGenerations = Pool(20, timeout: poolTimeout);
}
const int retryLimit = 3;
final LRUMap<String, Uint8List?> faceCropCache = LRUMap(1000);
final LRUMap<String, Uint8List?> faceCropThumbnailCache = LRUMap(1000);
final poolFullFileFaceGenerations =
Pool poolFullFileFaceGenerations =
Pool(20, timeout: const Duration(seconds: 15));
final poolThumbnailFaceGenerations =
Pool poolThumbnailFaceGenerations =
Pool(100, timeout: const Duration(seconds: 15));
Future<Map<String, Uint8List>?> getFaceCrops(
EnteFile file,

View File

@@ -12,7 +12,7 @@ description: ente photos application
# Read more about iOS versioning at
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
version: 0.8.120+640
version: 0.8.122+642
publish_to: none
environment: