From 1b314687a9b51abe24d5211466820a6632c3da01 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:50:01 +0530 Subject: [PATCH 01/17] Add dependency on nano_id pkg --- mobile/pubspec.lock | 8 ++++++++ mobile/pubspec.yaml | 1 + 2 files changed, 9 insertions(+) diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock index 8372373efb..aa1f0fb41f 100644 --- a/mobile/pubspec.lock +++ b/mobile/pubspec.lock @@ -1608,6 +1608,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.3.2+6" + nanoid: + dependency: "direct main" + description: + name: nanoid + sha256: be3f8752d9046c825df2f3914195151eb876f3ad64b9d833dd0b799b77b8759e + url: "https://pub.dev" + source: hosted + version: "1.0.0" nested: dependency: transitive description: diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index ad2dfc6912..af24aea503 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -123,6 +123,7 @@ dependencies: motionphoto: git: "https://github.com/ente-io/motionphoto.git" move_to_background: ^1.0.2 + nanoid: ^1.0.0 onnx_dart: path: plugins/onnx_dart onnxruntime: From f1d6382b399c6d0aa9fa2086f554cb6678a4437d Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:57:18 +0530 Subject: [PATCH 02/17] Add cluster nanoID --- mobile/lib/models/nanoids/cluster_id.dart | 59 +++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 mobile/lib/models/nanoids/cluster_id.dart diff --git a/mobile/lib/models/nanoids/cluster_id.dart b/mobile/lib/models/nanoids/cluster_id.dart new file mode 100644 index 0000000000..2376c7a014 --- /dev/null +++ b/mobile/lib/models/nanoids/cluster_id.dart @@ -0,0 +1,59 @@ +import "package:flutter/foundation.dart"; +import 'package:nanoid/nanoid.dart'; + +const alphaphet = + '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'; +const clusterIDLength = 22; + +class ClusterID { + final String value; + + // Private constructor + ClusterID._internal(this.value); + + // Factory constructor with validation + factory ClusterID(String value) { + if (!_isValidClusterID(value)) { + throw const FormatException('Invalid NanoID format'); + } + return ClusterID._internal(value); + } + + // Static method to generate a new NanoID + static ClusterID generate() { + return ClusterID("cluster_${customAlphabet(urlAlphabet, clusterIDLength)}"); + } + + // Validation method + static bool _isValidClusterID(String value) { + if (value.length != (clusterIDLength + 8)) { + debugPrint("ClusterID length is not ${clusterIDLength + 8}: $value"); + return false; + } + if (value.startsWith("cluster_")) { + debugPrint("ClusterID doesn't start with _cluster: $value"); + return false; + } + return true; + } + + // Override == operator + @override + bool operator ==(Object other) { + if (identical(this, other)) return true; + return other is ClusterID && other.value == value; + } + + // Override hashCode + @override + int get hashCode => value.hashCode; + + @override + String toString() => value; + + String toJson() => value; + + static ClusterID fromJson(String value) { + return ClusterID(value); + } +} From ffd9e2e06cd8399615a94d5b20b20d52c735d06f Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:25:13 +0530 Subject: [PATCH 03/17] Switch to nano_id for clusterIDs --- mobile/lib/face/db.dart | 134 ++++++++--------- mobile/lib/face/db_fields.dart | 8 +- mobile/lib/face/model/person.dart | 4 +- mobile/lib/models/nanoids/cluster_id.dart | 40 +---- .../face_clustering_service.dart | 64 ++++---- .../face_db_info_for_clustering.dart | 2 +- .../face_ml/feedback/cluster_feedback.dart | 138 +++++++++--------- .../face_ml/person/person_service.dart | 15 +- .../services/machine_learning/ml_service.dart | 2 +- mobile/lib/services/search_service.dart | 10 +- .../bottom_action_bar_widget.dart | 2 +- .../file_selection_actions_widget.dart | 2 +- .../actions/file_selection_overlay_bar.dart | 2 +- .../ui/viewer/file_details/face_widget.dart | 5 +- .../file_details/faces_item_widget.dart | 7 +- .../people/add_person_action_sheet.dart | 6 +- .../lib/ui/viewer/people/cluster_app_bar.dart | 14 +- .../viewer/people/cluster_breakup_page.dart | 6 +- mobile/lib/ui/viewer/people/cluster_page.dart | 2 +- .../people/person_cluster_suggestion.dart | 14 +- .../viewer/people/person_clusters_page.dart | 20 ++- .../search/result/person_face_widget.dart | 2 +- .../ui/viewer/search_tab/people_section.dart | 2 +- 23 files changed, 239 insertions(+), 262 deletions(-) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 81c896e164..cdc014f86e 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -28,7 +28,7 @@ import 'package:sqlite_async/sqlite_async.dart'; class FaceMLDataDB { static final Logger _logger = Logger("FaceMLDataDB"); - static const _databaseName = "ente.face_ml_db.db"; + static const _databaseName = "ente.face_ml_db_v2.db"; // static const _databaseVersion = 1; FaceMLDataDB._privateConstructor(); @@ -136,7 +136,7 @@ class FaceMLDataDB { } Future updateFaceIdToClusterId( - Map faceIDToClusterID, + Map faceIDToClusterID, ) async { final db = await instance.asyncDB; const batchSize = 500; @@ -185,43 +185,43 @@ class FaceMLDataDB { return maps.first['count'] as int; } - Future> clusterIdToFaceCount() async { + Future> clusterIdToFaceCount() async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $fcClusterID, COUNT(*) as count FROM $faceClustersTable where $fcClusterID IS NOT NULL GROUP BY $fcClusterID ', ); - final Map result = {}; + final Map result = {}; for (final map in maps) { - result[map[fcClusterID] as int] = map['count'] as int; + result[map[fcClusterID] as String] = map['count'] as int; } return result; } - Future> getPersonIgnoredClusters(String personID) async { + Future> getPersonIgnoredClusters(String personID) async { final db = await instance.asyncDB; // find out clusterIds that are assigned to other persons using the clusters table final List> otherPersonMaps = await db.getAll( 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', [personID], ); - final Set ignoredClusterIDs = - otherPersonMaps.map((e) => e[clusterIDColumn] as int).toSet(); + final Set ignoredClusterIDs = + otherPersonMaps.map((e) => e[clusterIDColumn] as String).toSet(); final List> rejectMaps = await db.getAll( 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', [personID], ); - final Set rejectClusterIDs = - rejectMaps.map((e) => e[clusterIDColumn] as int).toSet(); + final Set rejectClusterIDs = + rejectMaps.map((e) => e[clusterIDColumn] as String).toSet(); return ignoredClusterIDs.union(rejectClusterIDs); } - Future> getPersonClusterIDs(String personID) async { + Future> getPersonClusterIDs(String personID) async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', [personID], ); - return maps.map((e) => e[clusterIDColumn] as int).toSet(); + return maps.map((e) => e[clusterIDColumn] as String).toSet(); } Future clearTable() async { @@ -235,7 +235,7 @@ class FaceMLDataDB { } Future> getFaceEmbeddingsForCluster( - int clusterID, { + String clusterID, { int? limit, }) async { final db = await instance.asyncDB; @@ -246,12 +246,12 @@ class FaceMLDataDB { return maps.map((e) => e[faceEmbeddingBlob] as Uint8List); } - Future>> getFaceEmbeddingsForClusters( - Iterable clusterIDs, { + Future>> getFaceEmbeddingsForClusters( + Iterable clusterIDs, { int? limit, }) async { final db = await instance.asyncDB; - final Map> result = {}; + final Map> result = {}; final selectQuery = ''' SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob @@ -264,7 +264,7 @@ class FaceMLDataDB { final List> maps = await db.getAll(selectQuery); for (final map in maps) { - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceEmbedding = map[faceEmbeddingBlob] as Uint8List; result.putIfAbsent(clusterID, () => []).add(faceEmbedding); } @@ -276,7 +276,7 @@ class FaceMLDataDB { required int recentFileID, String? personID, String? avatarFaceId, - int? clusterID, + String? clusterID, }) async { // read person from db final db = await instance.asyncDB; @@ -299,7 +299,7 @@ class FaceMLDataDB { [personID], ); final clusterIDs = - clusterRows.map((e) => e[clusterIDColumn] as int).toList(); + clusterRows.map((e) => e[clusterIDColumn] as String).toList(); final List> faceMaps = await db.getAll( 'SELECT * FROM $facesTable where ' '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' @@ -359,23 +359,23 @@ class FaceMLDataDB { return maps.map((e) => mapRowToFace(e)).toList(); } - Future>> getClusterToFaceIDs( - Set clusterIDs, + Future>> getClusterToFaceIDs( + Set clusterIDs, ) async { final db = await instance.asyncDB; - final Map> result = {}; + final Map> result = {}; final List> maps = await db.getAll( 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})', ); for (final map in maps) { - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; result.putIfAbsent(clusterID, () => []).add(faceID); } return result; } - Future getClusterIDForFaceID(String faceID) async { + Future getClusterIDForFaceID(String faceID) async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $fcClusterID FROM $faceClustersTable WHERE $fcFaceId = ?', @@ -384,24 +384,24 @@ class FaceMLDataDB { if (maps.isEmpty) { return null; } - return maps.first[fcClusterID] as int; + return maps.first[fcClusterID] as String; } - Future>> getAllClusterIdToFaceIDs() async { + Future>> getAllClusterIdToFaceIDs() async { final db = await instance.asyncDB; - final Map> result = {}; + final Map> result = {}; final List> maps = await db.getAll( 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable', ); for (final map in maps) { - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; result.putIfAbsent(clusterID, () => []).add(faceID); } return result; } - Future> getFaceIDsForCluster(int clusterID) async { + Future> getFaceIDsForCluster(String clusterID) async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $fcFaceId FROM $faceClustersTable ' @@ -412,17 +412,17 @@ class FaceMLDataDB { } // Get Map of personID to Map of clusterID to faceIDs - Future>>> + Future>>> getPersonToClusterIdToFaceIds() async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $personIdColumn, $faceClustersTable.$fcClusterID, $fcFaceId FROM $clusterPersonTable ' 'LEFT JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$fcClusterID', ); - final Map>> result = {}; + final Map>> result = {}; for (final map in maps) { final personID = map[personIdColumn] as String; - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; result .putIfAbsent(personID, () => {}) @@ -443,7 +443,7 @@ class FaceMLDataDB { return faceIdsResult.map((e) => e[fcFaceId] as String).toSet(); } - Future> getBlurValuesForCluster(int clusterID) async { + Future> getBlurValuesForCluster(String clusterID) async { final db = await instance.asyncDB; const String query = ''' SELECT $facesTable.$faceBlur @@ -463,29 +463,29 @@ class FaceMLDataDB { return maps.map((e) => e[faceBlur] as double).toSet(); } - Future> getFaceIdsToClusterIds( + Future> getFaceIdsToClusterIds( Iterable faceIds, ) async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})', ); - final Map result = {}; + final Map result = {}; for (final map in maps) { - result[map[fcFaceId] as String] = map[fcClusterID] as int?; + result[map[fcFaceId] as String] = map[fcClusterID] as String?; } return result; } - Future>> getFileIdToClusterIds() async { - final Map> result = {}; + Future>> getFileIdToClusterIds() async { + final Map> result = {}; final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable', ); for (final map in maps) { - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); @@ -494,7 +494,7 @@ class FaceMLDataDB { } Future forceUpdateClusterIds( - Map faceIDToClusterID, + Map faceIDToClusterID, ) async { final db = await instance.asyncDB; @@ -681,7 +681,7 @@ class FaceMLDataDB { Future assignClusterToPerson({ required String personID, - required int clusterID, + required String clusterID, }) async { final db = await instance.asyncDB; @@ -692,7 +692,7 @@ class FaceMLDataDB { } Future bulkAssignClusterToPersonID( - Map clusterToPersonID, + Map clusterToPersonID, ) async { final db = await instance.asyncDB; @@ -706,7 +706,7 @@ class FaceMLDataDB { Future captureNotPersonFeedback({ required String personID, - required int clusterID, + required String clusterID, }) async { final db = await instance.asyncDB; @@ -717,7 +717,7 @@ class FaceMLDataDB { } Future bulkCaptureNotPersonFeedback( - Map clusterToPersonID, + Map clusterToPersonID, ) async { final db = await instance.asyncDB; @@ -732,7 +732,7 @@ class FaceMLDataDB { Future removeNotPersonFeedback({ required String personID, - required int clusterID, + required String clusterID, }) async { final db = await instance.asyncDB; @@ -744,7 +744,7 @@ class FaceMLDataDB { Future removeClusterToPerson({ required String personID, - required int clusterID, + required String clusterID, }) async { final db = await instance.asyncDB; @@ -755,7 +755,7 @@ class FaceMLDataDB { } // for a given personID, return a map of clusterID to fileIDs using join query - Future>> getFileIdToClusterIDSet(String personID) { + Future>> getFileIdToClusterIDSet(String personID) { final db = instance.asyncDB; return db.then((db) async { final List> maps = await db.getAll( @@ -765,9 +765,9 @@ class FaceMLDataDB { 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); - final Map> result = {}; + final Map> result = {}; for (final map in maps) { - final clusterID = map[clusterIDColumn] as int; + final clusterID = map[clusterIDColumn] as String; final String faceID = map[fcFaceId] as String; final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); @@ -776,8 +776,8 @@ class FaceMLDataDB { }); } - Future>> getFileIdToClusterIDSetForCluster( - Set clusterIDs, + Future>> getFileIdToClusterIDSetForCluster( + Set clusterIDs, ) { final db = instance.asyncDB; return db.then((db) async { @@ -785,9 +785,9 @@ class FaceMLDataDB { 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable ' 'WHERE $fcClusterID IN (${clusterIDs.join(",")})', ); - final Map> result = {}; + final Map> result = {}; for (final map in maps) { - final clusterID = map[fcClusterID] as int; + final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); @@ -796,7 +796,9 @@ class FaceMLDataDB { }); } - Future clusterSummaryUpdate(Map summary) async { + Future clusterSummaryUpdate( + Map summary, + ) async { final db = await instance.asyncDB; const String sql = ''' @@ -810,7 +812,7 @@ class FaceMLDataDB { batchCounter = 0; parameterSets.clear(); } - final int clusterID = entry.key; + final String clusterID = entry.key; final int count = entry.value.$2; final Uint8List avg = entry.value.$1; parameterSets.add([clusterID, avg, count]); @@ -819,7 +821,7 @@ class FaceMLDataDB { await db.executeBatch(sql, parameterSets); } - Future deleteClusterSummary(int clusterID) async { + Future deleteClusterSummary(String clusterID) async { final db = await instance.asyncDB; const String sqlDelete = 'DELETE FROM $clusterSummaryTable WHERE $clusterIDColumn = ?'; @@ -827,16 +829,16 @@ class FaceMLDataDB { } /// Returns a map of clusterID to (avg embedding, count) - Future> getAllClusterSummary([ + Future> getAllClusterSummary([ int? minClusterSize, ]) async { final db = await instance.asyncDB; - final Map result = {}; + final Map result = {}; final rows = await db.getAll( 'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}', ); for (final r in rows) { - final id = r[clusterIDColumn] as int; + final id = r[clusterIDColumn] as String; final avg = r[avgColumn] as Uint8List; final count = r[countColumn] as int; result[id] = (avg, count); @@ -844,16 +846,16 @@ class FaceMLDataDB { return result; } - Future> getClusterToClusterSummary( - Iterable clusterIDs, + Future> getClusterToClusterSummary( + Iterable clusterIDs, ) async { final db = await instance.asyncDB; - final Map result = {}; + final Map result = {}; final rows = await db.getAll( 'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})', ); for (final r in rows) { - final id = r[clusterIDColumn] as int; + final id = r[clusterIDColumn] as String; final avg = r[avgColumn] as Uint8List; final count = r[countColumn] as int; result[id] = (avg, count); @@ -861,14 +863,14 @@ class FaceMLDataDB { return result; } - Future> getClusterIDToPersonID() async { + Future> getClusterIDToPersonID() async { final db = await instance.asyncDB; final List> maps = await db.getAll( 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable', ); - final Map result = {}; + final Map result = {}; for (final map in maps) { - result[map[clusterIDColumn] as int] = map[personIdColumn] as String; + result[map[clusterIDColumn] as String] = map[personIdColumn] as String; } return result; } diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 8ad14ae282..c1de4780bc 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -41,7 +41,7 @@ const fcFaceId = 'face_id'; const createFaceClustersTable = ''' CREATE TABLE IF NOT EXISTS $faceClustersTable ( $fcFaceId TEXT NOT NULL, - $fcClusterID INTEGER NOT NULL, + $fcClusterID TEXT NOT NULL, PRIMARY KEY($fcFaceId) ); '''; @@ -59,7 +59,7 @@ const clusterIDColumn = 'cluster_id'; const createClusterPersonTable = ''' CREATE TABLE IF NOT EXISTS $clusterPersonTable ( $personIdColumn TEXT NOT NULL, - $clusterIDColumn INTEGER NOT NULL, + $clusterIDColumn TEXT NOT NULL, PRIMARY KEY($personIdColumn, $clusterIDColumn) ); '''; @@ -72,7 +72,7 @@ const avgColumn = 'avg'; const countColumn = 'count'; const createClusterSummaryTable = ''' CREATE TABLE IF NOT EXISTS $clusterSummaryTable ( - $clusterIDColumn INTEGER NOT NULL, + $clusterIDColumn TEXT NOT NULL, $avgColumn BLOB NOT NULL, $countColumn INTEGER NOT NULL, PRIMARY KEY($clusterIDColumn) @@ -89,7 +89,7 @@ const notPersonFeedback = 'not_person_feedback'; const createNotPersonFeedbackTable = ''' CREATE TABLE IF NOT EXISTS $notPersonFeedback ( $personIdColumn TEXT NOT NULL, - $clusterIDColumn INTEGER NOT NULL, + $clusterIDColumn TEXT NOT NULL, PRIMARY KEY($personIdColumn, $clusterIDColumn) ); '''; diff --git a/mobile/lib/face/model/person.dart b/mobile/lib/face/model/person.dart index cedec7a0dc..7536d3b0a6 100644 --- a/mobile/lib/face/model/person.dart +++ b/mobile/lib/face/model/person.dart @@ -24,7 +24,7 @@ class PersonEntity { } class ClusterInfo { - final int id; + final String id; final Set faces; ClusterInfo({ required this.id, @@ -40,7 +40,7 @@ class ClusterInfo { // from Json factory ClusterInfo.fromJson(Map json) { return ClusterInfo( - id: json['id'] as int, + id: json['id'] as String, faces: (json['faces'] as List).map((e) => e as String).toSet(), ); } diff --git a/mobile/lib/models/nanoids/cluster_id.dart b/mobile/lib/models/nanoids/cluster_id.dart index 2376c7a014..278e0bb1a2 100644 --- a/mobile/lib/models/nanoids/cluster_id.dart +++ b/mobile/lib/models/nanoids/cluster_id.dart @@ -6,26 +6,12 @@ const alphaphet = const clusterIDLength = 22; class ClusterID { - final String value; - - // Private constructor - ClusterID._internal(this.value); - - // Factory constructor with validation - factory ClusterID(String value) { - if (!_isValidClusterID(value)) { - throw const FormatException('Invalid NanoID format'); - } - return ClusterID._internal(value); - } - - // Static method to generate a new NanoID - static ClusterID generate() { - return ClusterID("cluster_${customAlphabet(urlAlphabet, clusterIDLength)}"); + static String generate() { + return "cluster_${customAlphabet(urlAlphabet, clusterIDLength)}"; } // Validation method - static bool _isValidClusterID(String value) { + static bool isValidClusterID(String value) { if (value.length != (clusterIDLength + 8)) { debugPrint("ClusterID length is not ${clusterIDLength + 8}: $value"); return false; @@ -36,24 +22,4 @@ class ClusterID { } return true; } - - // Override == operator - @override - bool operator ==(Object other) { - if (identical(this, other)) return true; - return other is ClusterID && other.value == value; - } - - // Override hashCode - @override - int get hashCode => value.hashCode; - - @override - String toString() => value; - - String toJson() => value; - - static ClusterID fromJson(String value) { - return ClusterID(value); - } } diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 69b8fc06b8..d33b69c97a 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -1,7 +1,6 @@ import "dart:async"; import "dart:developer"; import "dart:isolate"; -import "dart:math" show max; import "dart:typed_data" show Uint8List; import "package:computer/computer.dart"; @@ -10,6 +9,7 @@ import "package:logging/logging.dart"; import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; +import "package:photos/models/nanoids/cluster_id.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/ml_result.dart"; @@ -21,7 +21,7 @@ class FaceInfo { final double? blurValue; final bool? badFace; final Vector? vEmbedding; - int? clusterId; + String? clusterId; String? closestFaceId; int? closestDist; int? fileCreationTime; @@ -39,9 +39,9 @@ class FaceInfo { enum ClusterOperation { linearIncrementalClustering } class ClusteringResult { - final Map newFaceIdToCluster; - final Map> newClusterIdToFaceIds; - final Map newClusterSummaries; + final Map newFaceIdToCluster; + final Map> newClusterIdToFaceIds; + final Map newClusterSummaries; bool get isEmpty => newFaceIdToCluster.isEmpty; @@ -210,7 +210,7 @@ class FaceClusteringService { double conservativeDistanceThreshold = kConservativeDistanceThreshold, bool useDynamicThreshold = true, int? offset, - required Map oldClusterSummaries, + required Map oldClusterSummaries, }) async { if (input.isEmpty) { _logger.warning( @@ -417,7 +417,7 @@ ClusteringResult _runLinearClustering(Map args) { final useDynamicThreshold = args['useDynamicThreshold'] as bool; final offset = args['offset'] as int?; final oldClusterSummaries = - args['oldClusterSummaries'] as Map?; + args['oldClusterSummaries'] as Map?; log( "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", @@ -491,17 +491,17 @@ ClusteringResult _runLinearClustering(Map args) { "[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces in total in this round ${offset != null ? "on top of ${offset + facesWithClusterID.length} earlier processed faces" : ""}", ); // set current epoch time as clusterID - int clusterID = DateTime.now().microsecondsSinceEpoch; + String clusterID = ClusterID.generate(); if (facesWithClusterID.isEmpty) { // assign a clusterID to the first face sortedFaceInfos[0].clusterId = clusterID; - clusterID++; + clusterID = ClusterID.generate(); } final stopwatchClustering = Stopwatch()..start(); for (int i = 1; i < totalFaces; i++) { // Incremental clustering, so we can skip faces that already have a clusterId if (sortedFaceInfos[i].clusterId != null) { - clusterID = max(clusterID, sortedFaceInfos[i].clusterId!); + // clusterID = max(clusterID, sortedFaceInfos[i].clusterId!); continue; } @@ -539,25 +539,25 @@ ClusteringResult _runLinearClustering(Map args) { log( " [ClusterIsolate] [WARNING] ${DateTime.now()} Found new cluster $clusterID", ); - clusterID++; + clusterID = ClusterID.generate(); sortedFaceInfos[closestIdx].clusterId = clusterID; } sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; } else { - clusterID++; + clusterID = ClusterID.generate(); sortedFaceInfos[i].clusterId = clusterID; } } // Finally, assign the new clusterId to the faces - final Map newFaceIdToCluster = {}; + final Map newFaceIdToCluster = {}; final newClusteredFaceInfos = sortedFaceInfos.sublist(alreadyClusteredCount); for (final faceInfo in newClusteredFaceInfos) { newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; } // Create a map of clusterId to faceIds - final Map> clusterIdToFaceIds = {}; + final Map> clusterIdToFaceIds = {}; for (final entry in newFaceIdToCluster.entries) { final clusterID = entry.value; if (clusterIdToFaceIds.containsKey(clusterID)) { @@ -599,7 +599,7 @@ ClusteringResult _runCompleteClustering(Map args) { final distanceThreshold = args['distanceThreshold'] as double; final mergeThreshold = args['mergeThreshold'] as double; final oldClusterSummaries = - args['oldClusterSummaries'] as Map?; + args['oldClusterSummaries'] as Map?; log( "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", @@ -634,11 +634,10 @@ ClusteringResult _runCompleteClustering(Map args) { "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering", ); - // set current epoch time as clusterID - int clusterID = DateTime.now().microsecondsSinceEpoch; + String clusterID = ClusterID.generate(); // Start actual clustering - final Map newFaceIdToCluster = {}; + final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 0; i < totalFaces; i++) { if ((i + 1) % 250 == 0) { @@ -659,18 +658,18 @@ ClusteringResult _runCompleteClustering(Map args) { if (closestDistance < distanceThreshold) { if (faceInfos[closestIdx].clusterId == null) { - clusterID++; + clusterID = ClusterID.generate(); faceInfos[closestIdx].clusterId = clusterID; } faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; } else { - clusterID++; + clusterID = ClusterID.generate(); faceInfos[i].clusterId = clusterID; } } // Now calculate the mean of the embeddings for each cluster - final Map> clusterIdToFaceInfos = {}; + final Map> clusterIdToFaceInfos = {}; for (final faceInfo in faceInfos) { if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) { clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo); @@ -678,7 +677,7 @@ ClusteringResult _runCompleteClustering(Map args) { clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; } } - final Map clusterIdToMeanEmbeddingAndWeight = {}; + final Map clusterIdToMeanEmbeddingAndWeight = {}; for (final clusterId in clusterIdToFaceInfos.keys) { final List embeddings = clusterIdToFaceInfos[clusterId]! .map((faceInfo) => faceInfo.vEmbedding!) @@ -691,13 +690,14 @@ ClusteringResult _runCompleteClustering(Map args) { } // Now merge the clusters that are close to each other, based on mean embedding - final List<(int, int)> mergedClustersList = []; - final List clusterIds = clusterIdToMeanEmbeddingAndWeight.keys.toList(); + final List<(String, String)> mergedClustersList = []; + final List clusterIds = + clusterIdToMeanEmbeddingAndWeight.keys.toList(); log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); while (true) { if (clusterIds.length < 2) break; double distance = double.infinity; - (int, int) clusterIDsToMerge = (-1, -1); + (String, String) clusterIDsToMerge = ('', ''); for (int i = 0; i < clusterIds.length; i++) { for (int j = 0; j < clusterIds.length; j++) { if (i == j) continue; @@ -749,7 +749,7 @@ ClusteringResult _runCompleteClustering(Map args) { newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; } - final Map> clusterIdToFaceIds = {}; + final Map> clusterIdToFaceIds = {}; for (final entry in newFaceIdToCluster.entries) { final clusterID = entry.value; if (clusterIdToFaceIds.containsKey(clusterID)) { @@ -794,12 +794,12 @@ void _sortFaceInfosOnCreationTime( }); } -Map _updateClusterSummaries({ +Map _updateClusterSummaries({ required List newFaceInfos, - Map? oldSummary, + Map? oldSummary, }) { final calcSummariesStart = DateTime.now(); - final Map> newClusterIdToFaceInfos = {}; + final Map> newClusterIdToFaceInfos = {}; for (final faceInfo in newFaceInfos) { if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) { newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo); @@ -808,7 +808,7 @@ Map _updateClusterSummaries({ } } - final Map newClusterSummaries = {}; + final Map newClusterSummaries = {}; for (final clusterId in newClusterIdToFaceInfos.keys) { final List newEmbeddings = newClusterIdToFaceInfos[clusterId]! .map((faceInfo) => faceInfo.vEmbedding!) @@ -849,13 +849,13 @@ void _analyzeClusterResults(List sortedFaceInfos) { if (!kDebugMode) return; final stopwatch = Stopwatch()..start(); - final Map faceIdToCluster = {}; + final Map faceIdToCluster = {}; for (final faceInfo in sortedFaceInfos) { faceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; } // Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs - final Map clusterIdToSize = {}; + final Map clusterIdToSize = {}; faceIdToCluster.forEach((key, value) { if (clusterIdToSize.containsKey(value)) { clusterIdToSize[value] = clusterIdToSize[value]! + 1; diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart index 12754301bb..1822d3fb55 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart @@ -2,7 +2,7 @@ import "dart:typed_data" show Uint8List; class FaceDbInfoForClustering { final String faceID; - int? clusterId; + String? clusterId; final Uint8List embeddingBytes; final double faceScore; final double blurValue; diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 9da4f1ce2f..85fb95af75 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -20,7 +20,7 @@ import "package:photos/services/machine_learning/ml_result.dart"; import "package:photos/services/search_service.dart"; class ClusterSuggestion { - final int clusterIDToMerge; + final String clusterIDToMerge; final double distancePersonToCluster; final bool usedOnlyMeanForSuggestion; final List filesInCluster; @@ -43,13 +43,13 @@ class ClusterFeedbackService { static final ClusterFeedbackService instance = ClusterFeedbackService._privateConstructor(); - static int lastViewedClusterID = -1; - static setLastViewedClusterID(int clusterID) { + static String lastViewedClusterID = ''; + static setLastViewedClusterID(String clusterID) { lastViewedClusterID = clusterID; } static resetLastViewedClusterID() { - lastViewedClusterID = -1; + lastViewedClusterID = ''; } /// Returns a list of cluster suggestions for a person. Each suggestion is a tuple of the following elements: @@ -68,7 +68,7 @@ class ClusterFeedbackService { try { // Get the suggestions for the person using centroids and median final startTime = DateTime.now(); - final List<(int, double, bool)> foundSuggestions = + final List<(String, double, bool)> foundSuggestions = await _getSuggestions(person); final findSuggestionsTime = DateTime.now(); _logger.info( @@ -77,13 +77,13 @@ class ClusterFeedbackService { // Get the files for the suggestions final suggestionClusterIDs = foundSuggestions.map((e) => e.$1).toSet(); - final Map> fileIdToClusterID = + final Map> fileIdToClusterID = await FaceMLDataDB.instance.getFileIdToClusterIDSetForCluster( suggestionClusterIDs, ); final clusterIdToFaceIDs = await FaceMLDataDB.instance.getClusterToFaceIDs(suggestionClusterIDs); - final Map> clusterIDToFiles = {}; + final Map> clusterIDToFiles = {}; final allFiles = await SearchService.instance.getAllFiles(); for (final f in allFiles) { if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) { @@ -180,7 +180,7 @@ class ClusterFeedbackService { .clusterSummaryUpdate(clusterResult.newClusterSummaries); // Make sure the deleted faces don't get suggested in the future - final notClusterIdToPersonId = {}; + final notClusterIdToPersonId = {}; for (final clusterId in newFaceIdToClusterID.values.toSet()) { notClusterIdToPersonId[clusterId] = p.remoteID; } @@ -202,7 +202,7 @@ class ClusterFeedbackService { Future removeFilesFromCluster( List files, - int clusterID, + String clusterID, ) async { _logger.info('removeFilesFromCluster called'); try { @@ -260,8 +260,8 @@ class ClusterFeedbackService { } } - Future addFacesToCluster(List faceIDs, int clusterID) async { - final faceIDToClusterID = {}; + Future addFacesToCluster(List faceIDs, String clusterID) async { + final faceIDToClusterID = {}; for (final faceID in faceIDs) { faceIDToClusterID[faceID] = clusterID; } @@ -272,7 +272,7 @@ class ClusterFeedbackService { Future checkAndDoAutomaticMerges( PersonEntity p, { - required int personClusterID, + required String personClusterID, }) async { final faceMlDb = FaceMLDataDB.instance; final faceIDs = await faceMlDb.getFaceIDsForCluster(personClusterID); @@ -293,7 +293,7 @@ class ClusterFeedbackService { // Get and update the cluster summary to get the avg (centroid) and count final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); - final Map clusterAvg = await _getUpdateClusterAvg( + final Map clusterAvg = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, minClusterSize: kMinimumClusterSizeSearchResult, @@ -301,7 +301,8 @@ class ClusterFeedbackService { watch.log('computed avg for ${clusterAvg.length} clusters'); // Find the actual closest clusters for the person - final List<(int, double)> suggestions = await calcSuggestionsMeanInComputer( + final List<(String, double)> suggestions = + await calcSuggestionsMeanInComputer( clusterAvg, {personClusterID}, ignoredClusters, @@ -333,16 +334,16 @@ class ClusterFeedbackService { return true; } - Future ignoreCluster(int clusterID) async { + Future ignoreCluster(String clusterID) async { await PersonService.instance.addPerson('', clusterID); Bus.instance.fire(PeopleChangedEvent()); return; } - Future> checkForMixedClusters() async { + Future> checkForMixedClusters() async { final faceMlDb = FaceMLDataDB.instance; final allClusterToFaceCount = await faceMlDb.clusterIdToFaceCount(); - final clustersToInspect = []; + final clustersToInspect = []; for (final clusterID in allClusterToFaceCount.keys) { if (allClusterToFaceCount[clusterID]! > 20 && allClusterToFaceCount[clusterID]! < 500) { @@ -353,7 +354,7 @@ class ClusterFeedbackService { final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); - final susClusters = <(int, int)>[]; + final susClusters = <(String, int)>[]; final inspectionStart = DateTime.now(); for (final clusterID in clustersToInspect) { @@ -387,15 +388,15 @@ class ClusterFeedbackService { ); // Now find the sizes of the biggest and second biggest cluster - final int biggestClusterID = newClusterIdToCount.keys.reduce((a, b) { + final String biggestClusterID = newClusterIdToCount.keys.reduce((a, b) { return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? a : b; }); final int biggestSize = newClusterIdToCount[biggestClusterID]!; final biggestRatio = biggestSize / originalClusterSize; if (newClusterIdToCount.length > 1) { - final List clusterIDs = newClusterIdToCount.keys.toList(); + final List clusterIDs = newClusterIdToCount.keys.toList(); clusterIDs.remove(biggestClusterID); - final int secondBiggestClusterID = clusterIDs.reduce((a, b) { + final String secondBiggestClusterID = clusterIDs.reduce((a, b) { return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? a : b; }); final int secondBiggestSize = @@ -432,7 +433,7 @@ class ClusterFeedbackService { } Future breakUpCluster( - int clusterID, { + String clusterID, { bool useDbscan = false, }) async { _logger.info( @@ -491,7 +492,7 @@ class ClusterFeedbackService { /// 1. clusterID: the ID of the cluster /// 2. distance: the distance between the person's cluster and the suggestion /// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false) - Future> _getSuggestions( + Future> _getSuggestions( PersonEntity p, { int sampleSize = 50, double maxMedianDistance = 0.62, @@ -520,8 +521,8 @@ class ClusterFeedbackService { .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0) .reduce((value, element) => min(value, element)); final checkSizes = [100, 20, kMinimumClusterSizeSearchResult, 10, 5, 1]; - Map clusterAvgBigClusters = {}; - final List<(int, double)> suggestionsMean = []; + Map clusterAvgBigClusters = {}; + final List<(String, double)> suggestionsMean = []; for (final minimumSize in checkSizes.toSet()) { if (smallestPersonClusterSize >= min(minimumSize, kMinimumClusterSizeSearchResult)) { @@ -533,7 +534,7 @@ class ClusterFeedbackService { w?.log( 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', ); - final List<(int, double)> suggestionsMeanBigClusters = + final List<(String, double)> suggestionsMeanBigClusters = await calcSuggestionsMeanInComputer( clusterAvgBigClusters, personClusters, @@ -570,7 +571,7 @@ class ClusterFeedbackService { // Find the other cluster candidates based on the median final clusterAvg = clusterAvgBigClusters; - final List<(int, double)> moreSuggestionsMean = + final List<(String, double)> moreSuggestionsMean = await calcSuggestionsMeanInComputer( clusterAvg, personClusters, @@ -616,8 +617,8 @@ class ClusterFeedbackService { .toList(growable: false); // Find the actual closest clusters for the person using median - final List<(int, double)> suggestionsMedian = []; - final List<(int, double)> greatSuggestionsMedian = []; + final List<(String, double)> suggestionsMedian = []; + final List<(String, double)> greatSuggestionsMedian = []; double minMedianDistance = maxMedianDistance; for (final otherClusterId in otherClusterIdsCandidates) { final Iterable otherEmbeddingsProto = @@ -663,11 +664,12 @@ class ClusterFeedbackService { _logger.info("Found suggestions using median: $suggestionsMedian"); } - final List<(int, double, bool)> finalSuggestionsMedian = suggestionsMedian - .map(((e) => (e.$1, e.$2, false))) - .toList(growable: false) - .reversed - .toList(growable: false); + final List<(String, double, bool)> finalSuggestionsMedian = + suggestionsMedian + .map(((e) => (e.$1, e.$2, false))) + .toList(growable: false) + .reversed + .toList(growable: false); if (greatSuggestionsMedian.isNotEmpty) { _logger.info( @@ -687,9 +689,9 @@ class ClusterFeedbackService { return finalSuggestionsMedian; } - Future> _getUpdateClusterAvg( - Map allClusterIdsToCountMap, - Set ignoredClusters, { + Future> _getUpdateClusterAvg( + Map allClusterIdsToCountMap, + Set ignoredClusters, { int minClusterSize = 1, int maxClusterInCurrentRun = 500, int maxEmbeddingToRead = 10000, @@ -701,9 +703,9 @@ class ClusterFeedbackService { 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun', ); - final Map clusterToSummary = + final Map clusterToSummary = await faceMlDb.getAllClusterSummary(minClusterSize); - final Map updatesForClusterSummary = {}; + final Map updatesForClusterSummary = {}; w?.log( 'getUpdateClusterAvg database call for getAllClusterSummary', @@ -717,7 +719,7 @@ class ClusterFeedbackService { 'ignoredClusters': ignoredClusters, 'clusterToSummary': clusterToSummary, }, - ) as (Map, Set, int, int, int); + ) as (Map, Set, int, int, int); final clusterAvg = serializationEmbeddings.$1; final allClusterIds = serializationEmbeddings.$2; final ignoredClustersCnt = serializationEmbeddings.$3; @@ -753,7 +755,7 @@ class ClusterFeedbackService { w?.reset(); int currentPendingRead = 0; - final List clusterIdsToRead = []; + final List clusterIdsToRead = []; for (final clusterID in sortedClusterIDs) { if (maxClusterInCurrentRun-- <= 0) { break; @@ -772,9 +774,9 @@ class ClusterFeedbackService { } } - final Map> clusterEmbeddings = await FaceMLDataDB - .instance - .getFaceEmbeddingsForClusters(clusterIdsToRead); + final Map> clusterEmbeddings = + await FaceMLDataDB.instance + .getFaceEmbeddingsForClusters(clusterIdsToRead); w?.logAndReset( 'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters', @@ -817,10 +819,10 @@ class ClusterFeedbackService { return clusterAvg; } - Future> calcSuggestionsMeanInComputer( - Map clusterAvg, - Set personClusters, - Set ignoredClusters, + Future> calcSuggestionsMeanInComputer( + Map clusterAvg, + Set personClusters, + Set ignoredClusters, double maxClusterDistance, ) async { return await _computer.compute( @@ -889,7 +891,7 @@ class ClusterFeedbackService { // Get the cluster averages for the person's clusters and the suggestions' clusters final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); - final Map personClusterToSummary = + final Map personClusterToSummary = await faceMlDb.getClusterToClusterSummary(personClusters); final clusterSummaryCallTime = DateTime.now(); @@ -975,7 +977,7 @@ class ClusterFeedbackService { } Future debugLogClusterBlurValues( - int clusterID, { + String clusterID, { int? clusterSize, bool logClusterSummary = false, bool logBlurValues = false, @@ -986,7 +988,8 @@ class ClusterFeedbackService { _logger.info( "Debug logging for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}", ); - const int biggestClusterID = 1715061228725148; + // todo:(laurens) remove to review + const String biggestClusterID = 'some random id'; // Logging the cluster summary for the cluster if (logClusterSummary) { @@ -1117,21 +1120,22 @@ class ClusterFeedbackService { } /// Returns a map of person's clusterID to map of closest clusterID to with disstance -List<(int, double)> _calcSuggestionsMean(Map args) { +List<(String, double)> _calcSuggestionsMean(Map args) { // Fill in args - final Map clusterAvg = args['clusterAvg']; - final Set personClusters = args['personClusters']; - final Set ignoredClusters = args['ignoredClusters']; + final Map clusterAvg = args['clusterAvg']; + final Set personClusters = args['personClusters']; + final Set ignoredClusters = args['ignoredClusters']; final double maxClusterDistance = args['maxClusterDistance']; - final Map> suggestions = {}; + final Map> suggestions = {}; const suggestionMax = 2000; int suggestionCount = 0; int comparisons = 0; final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); // ignore the clusters that belong to the person or is ignored - Set otherClusters = clusterAvg.keys.toSet().difference(personClusters); + Set otherClusters = + clusterAvg.keys.toSet().difference(personClusters); otherClusters = otherClusters.difference(ignoredClusters); for (final otherClusterID in otherClusters) { @@ -1140,7 +1144,7 @@ List<(int, double)> _calcSuggestionsMean(Map args) { dev.log('[WARNING] no avg for othercluster $otherClusterID'); continue; } - int? nearestPersonCluster; + String? nearestPersonCluster; double? minDistance; for (final personCluster in personClusters) { if (clusterAvg[personCluster] == null) { @@ -1172,8 +1176,8 @@ List<(int, double)> _calcSuggestionsMean(Map args) { ); if (suggestions.isNotEmpty) { - final List<(int, double)> suggestClusterIds = []; - for (final List<(int, double)> suggestion in suggestions.values) { + final List<(String, double)> suggestClusterIds = []; + for (final List<(String, double)> suggestion in suggestions.values) { suggestClusterIds.addAll(suggestion); } suggestClusterIds.sort( @@ -1186,20 +1190,22 @@ List<(int, double)> _calcSuggestionsMean(Map args) { return suggestClusterIds.sublist(0, min(suggestClusterIds.length, 20)); } else { dev.log("No suggestions found using mean"); - return <(int, double)>[]; + return <(String, double)>[]; } } -Future<(Map, Set, int, int, int)> +Future<(Map, Set, int, int, int)> checkAndSerializeCurrentClusterMeans( Map args, ) async { - final Map allClusterIdsToCountMap = args['allClusterIdsToCountMap']; + final Map allClusterIdsToCountMap = + args['allClusterIdsToCountMap']; final int minClusterSize = args['minClusterSize'] ?? 1; - final Set ignoredClusters = args['ignoredClusters'] ?? {}; - final Map clusterToSummary = args['clusterToSummary']; + final Set ignoredClusters = args['ignoredClusters'] ?? {}; + final Map clusterToSummary = + args['clusterToSummary']; - final Map clusterAvg = {}; + final Map clusterAvg = {}; final allClusterIds = allClusterIdsToCountMap.keys.toSet(); int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 4931bb0eb7..187e51faf6 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -82,7 +82,7 @@ class PersonService { continue; } final personData = person.data; - final Map> dbPersonCluster = + final Map> dbPersonCluster = dbPersonClusterInfo[personID]!; if (_shouldUpdateRemotePerson(personData, dbPersonCluster)) { final personData = person.data; @@ -109,7 +109,7 @@ class PersonService { bool _shouldUpdateRemotePerson( PersonData personData, - Map> dbPersonCluster, + Map> dbPersonCluster, ) { bool result = false; if ((personData.assigned?.length ?? 0) != dbPersonCluster.length) { @@ -152,7 +152,7 @@ class PersonService { Future addPerson( String name, - int clusterID, { + String clusterID, { bool isHidden = false, }) async { final faceIds = await faceMLDataDB.getFaceIDsForCluster(clusterID); @@ -179,7 +179,7 @@ class PersonService { Future removeClusterToPerson({ required String personID, - required int clusterID, + required String clusterID, }) async { final person = (await getPerson(personID))!; final personData = person.data; @@ -201,7 +201,7 @@ class PersonService { required Set faceIDs, }) async { final personData = person.data; - final List emptiedClusters = []; + final List emptiedClusters = []; for (final cluster in personData.assigned!) { cluster.faces.removeWhere((faceID) => faceIDs.contains(faceID)); if (cluster.faces.isEmpty) { @@ -219,7 +219,6 @@ class PersonService { ); } - await entityService.addOrUpdate( EntityType.person, json.encode(personData.toJson()), @@ -256,8 +255,8 @@ class PersonService { await entityService.syncEntities(); final entities = await entityService.getEntities(EntityType.person); entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt)); - final Map faceIdToClusterID = {}; - final Map clusterToPersonID = {}; + final Map faceIdToClusterID = {}; + final Map clusterToPersonID = {}; for (var e in entities) { final personData = PersonData.fromJson(json.decode(e.data)); int faceCount = 0; diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 7bf0e2dfd5..1ed24f0d46 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -250,7 +250,7 @@ class MLService { ); // Get the current cluster statistics - final Map oldClusterSummaries = + final Map oldClusterSummaries = await FaceMLDataDB.instance.getAllClusterSummary(); if (clusterInBuckets) { diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index b9b6736e3c..9329a79081 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -736,14 +736,14 @@ class SearchService { return searchResults; } - Future>> getClusterFilesForPersonID( + Future>> getClusterFilesForPersonID( String personID, ) async { _logger.info('getClusterFilesForPersonID $personID'); - final Map> fileIdToClusterID = + final Map> fileIdToClusterID = await FaceMLDataDB.instance.getFileIdToClusterIDSet(personID); _logger.info('faceDbDone getClusterFilesForPersonID $personID'); - final Map> clusterIDToFiles = {}; + final Map> clusterIDToFiles = {}; final allFiles = await getAllFiles(); for (final f in allFiles) { if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) { @@ -765,7 +765,7 @@ class SearchService { Future> getAllFace(int? limit) async { try { debugPrint("getting faces"); - final Map> fileIdToClusterID = + final Map> fileIdToClusterID = await FaceMLDataDB.instance.getFileIdToClusterIds(); final Map personIdToPerson = await PersonService.instance.getPersonsMap(); @@ -773,7 +773,7 @@ class SearchService { await FaceMLDataDB.instance.getClusterIDToPersonID(); final List facesResult = []; - final Map> clusterIdToFiles = {}; + final Map> clusterIdToFiles = {}; final Map> personIdToFiles = {}; final allFiles = await getAllFiles(); for (final f in allFiles) { diff --git a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart index a0c50be21d..d295d85e1c 100644 --- a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart +++ b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart @@ -13,7 +13,7 @@ class BottomActionBarWidget extends StatelessWidget { final GalleryType galleryType; final Collection? collection; final PersonEntity? person; - final int? clusterID; + final String? clusterID; final SelectedFiles selectedFiles; final VoidCallback? onCancel; final Color? backgroundColor; diff --git a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart index 9e75474d40..cf90489c6d 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart @@ -53,7 +53,7 @@ class FileSelectionActionsWidget extends StatefulWidget { final DeviceCollection? deviceCollection; final SelectedFiles selectedFiles; final PersonEntity? person; - final int? clusterID; + final String? clusterID; const FileSelectionActionsWidget( this.type, diff --git a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart index 9ba5adf3fb..bb9825f649 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart @@ -14,7 +14,7 @@ class FileSelectionOverlayBar extends StatefulWidget { final Collection? collection; final Color? backgroundColor; final PersonEntity? person; - final int? clusterID; + final String? clusterID; const FileSelectionOverlayBar( this.galleryType, diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 9c0d7decdb..7465f7c8c0 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -9,6 +9,7 @@ import "package:photos/face/db.dart"; import "package:photos/face/model/face.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/models/nanoids/cluster_id.dart"; import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/search_service.dart"; @@ -24,7 +25,7 @@ class FaceWidget extends StatefulWidget { final Face face; final Future?>? faceCrops; final PersonEntity? person; - final int? clusterID; + final String? clusterID; final bool highlight; final bool editMode; @@ -98,7 +99,7 @@ class _FaceWidgetState extends State { } // Create new clusterID for the faceID and update DB to assign the faceID to the new clusterID - final int newClusterID = DateTime.now().microsecondsSinceEpoch; + final String newClusterID = ClusterID.generate(); await FaceMLDataDB.instance.updateFaceIdToClusterId( {widget.face.faceID: newClusterID}, ); diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart index cb22e53b82..304beba5a4 100644 --- a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart +++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart @@ -146,7 +146,7 @@ class _FacesItemWidgetState extends State { final faceCrops = getRelevantFaceCrops(faces); for (final Face face in faces) { - final int? clusterID = faceIdsToClusterIds[face.faceID]; + final String? clusterID = faceIdsToClusterIds[face.faceID]; final PersonEntity? person = clusterIDToPerson[clusterID] != null ? persons[clusterIDToPerson[clusterID]!] : null; @@ -175,8 +175,7 @@ class _FacesItemWidgetState extends State { Future?> getRelevantFaceCrops( Iterable faces, { int fetchAttempt = 1, - } - ) async { + }) async { try { final faceIdToCrop = {}; final facesWithoutCrops = {}; @@ -226,7 +225,7 @@ class _FacesItemWidgetState extends State { stackTrace: s, ); resetPool(fullFile: true); - if(fetchAttempt <= retryLimit) { + if (fetchAttempt <= retryLimit) { return getRelevantFaceCrops(faces, fetchAttempt: fetchAttempt + 1); } return null; diff --git a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart index eb5b3e4b32..c868e60e71 100644 --- a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart +++ b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart @@ -47,7 +47,7 @@ String _actionName( Future showAssignPersonAction( BuildContext context, { - required int clusterID, + required String clusterID, PersonActionType actionType = PersonActionType.assignPerson, bool showOptionToAddNewPerson = true, }) { @@ -75,7 +75,7 @@ Future showAssignPersonAction( class PersonActionSheet extends StatefulWidget { final PersonActionType actionType; - final int cluserID; + final String cluserID; final bool showOptionToCreateNewPerson; const PersonActionSheet({ required this.actionType, @@ -276,7 +276,7 @@ class _PersonActionSheetState extends State { Future addNewPerson( BuildContext context, { String initValue = '', - required int clusterID, + required String clusterID, }) async { final result = await showTextInputDialog( context, diff --git a/mobile/lib/ui/viewer/people/cluster_app_bar.dart b/mobile/lib/ui/viewer/people/cluster_app_bar.dart index fbfade9013..25b16aab34 100644 --- a/mobile/lib/ui/viewer/people/cluster_app_bar.dart +++ b/mobile/lib/ui/viewer/people/cluster_app_bar.dart @@ -26,7 +26,7 @@ class ClusterAppBar extends StatefulWidget { final GalleryType type; final String? title; final SelectedFiles selectedFiles; - final int clusterID; + final String clusterID; final PersonEntity? person; const ClusterAppBar( @@ -179,7 +179,7 @@ class _AppBarWidgetState extends State { Future _breakUpCluster(BuildContext context) async { bool userConfirmed = false; List biggestClusterFiles = []; - int biggestClusterID = -1; + String biggestClusterID = ''; await showChoiceDialog( context, title: "Does this grouping contain multiple people?", @@ -190,9 +190,9 @@ class _AppBarWidgetState extends State { try { final breakupResult = await ClusterFeedbackService.instance .breakUpCluster(widget.clusterID); - final Map> newClusterIDToFaceIDs = + final Map> newClusterIDToFaceIDs = breakupResult.newClusterIdToFaceIds; - final Map newFaceIdToClusterID = + final Map newFaceIdToClusterID = breakupResult.newFaceIdToCluster; // Update to delete the old clusters and save the new clusters @@ -203,9 +203,9 @@ class _AppBarWidgetState extends State { .updateFaceIdToClusterId(newFaceIdToClusterID); // Find the biggest cluster - biggestClusterID = -1; + biggestClusterID = ''; int biggestClusterSize = 0; - for (final MapEntry> clusterToFaces + for (final MapEntry> clusterToFaces in newClusterIDToFaceIDs.entries) { if (clusterToFaces.value.length > biggestClusterSize) { biggestClusterSize = clusterToFaces.value.length; @@ -253,7 +253,7 @@ class _AppBarWidgetState extends State { final breakupResult = await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID); - final Map> newClusterIDToFaceIDs = + final Map> newClusterIDToFaceIDs = breakupResult.newClusterIdToFaceIds; final allFileIDs = newClusterIDToFaceIDs.values diff --git a/mobile/lib/ui/viewer/people/cluster_breakup_page.dart b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart index 5644258325..c1bec40466 100644 --- a/mobile/lib/ui/viewer/people/cluster_breakup_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart @@ -6,7 +6,7 @@ import "package:photos/ui/viewer/people/cluster_page.dart"; import "package:photos/ui/viewer/search/result/person_face_widget.dart"; class ClusterBreakupPage extends StatefulWidget { - final Map> newClusterIDsToFiles; + final Map> newClusterIDsToFiles; final String title; const ClusterBreakupPage( @@ -32,7 +32,7 @@ class _ClusterBreakupPageState extends State { body: ListView.builder( itemCount: widget.newClusterIDsToFiles.keys.length, itemBuilder: (context, index) { - final int clusterID = keys[index]; + final String clusterID = keys[index]; final List files = clusterIDsToFiles[keys[index]]!; return InkWell( onTap: () { @@ -40,7 +40,7 @@ class _ClusterBreakupPageState extends State { MaterialPageRoute( builder: (context) => ClusterPage( files, - clusterID: index, + clusterID: clusterID, appendTitle: "(Analysis)", ), ), diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index efca41ecab..05b89d8247 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -29,7 +29,7 @@ class ClusterPage extends StatefulWidget { final List searchResult; final bool enableGrouping; final String tagPrefix; - final int clusterID; + final String clusterID; final PersonEntity? personID; final String appendTitle; final bool showNamingBanner; diff --git a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart index c3e788ffff..db1141fa9e 100644 --- a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart +++ b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart @@ -109,7 +109,7 @@ class _PersonClustersState extends State { allSuggestions = snapshot.data!; final numberOfDifferentSuggestions = allSuggestions.length; final currentSuggestion = allSuggestions[currentSuggestionIndex]; - final int clusterID = currentSuggestion.clusterIDToMerge; + final String clusterID = currentSuggestion.clusterIDToMerge; final double distance = currentSuggestion.distancePersonToCluster; final bool usingMean = currentSuggestion.usedOnlyMeanForSuggestion; final List files = currentSuggestion.filesInCluster; @@ -182,7 +182,7 @@ class _PersonClustersState extends State { } Future _handleUserClusterChoice( - int clusterID, + String clusterID, bool yesOrNo, int numberOfSuggestions, ) async { @@ -229,7 +229,7 @@ class _PersonClustersState extends State { } Future _rejectSuggestion( - int clusterID, + String clusterID, int numberOfSuggestions, ) async { canGiveFeedback = false; @@ -254,7 +254,7 @@ class _PersonClustersState extends State { } Widget _buildSuggestionView( - int clusterID, + String clusterID, double distance, bool usingMean, List files, @@ -379,7 +379,7 @@ class _PersonClustersState extends State { Widget _buildThumbnailWidget( List files, - int clusterID, + String clusterID, Future> generateFaceThumbnails, ) { return SizedBox( @@ -433,7 +433,7 @@ class _PersonClustersState extends State { List _buildThumbnailWidgetsRow( List files, - int cluserId, + String cluserId, Map faceThumbnails, { int start = 0, }) { @@ -460,7 +460,7 @@ class _PersonClustersState extends State { Future> _generateFaceThumbnails( List files, - int clusterID, + String clusterID, ) async { final futures = >[]; for (final file in files) { diff --git a/mobile/lib/ui/viewer/people/person_clusters_page.dart b/mobile/lib/ui/viewer/people/person_clusters_page.dart index 16d1131682..810fcbc779 100644 --- a/mobile/lib/ui/viewer/people/person_clusters_page.dart +++ b/mobile/lib/ui/viewer/people/person_clusters_page.dart @@ -32,12 +32,13 @@ class _PersonClustersPageState extends State { appBar: AppBar( title: Text(widget.person.data.name), ), - body: FutureBuilder>>( - future: SearchService.instance.getClusterFilesForPersonID(widget.person.remoteID), + body: FutureBuilder>>( + future: SearchService.instance + .getClusterFilesForPersonID(widget.person.remoteID), builder: (context, snapshot) { if (snapshot.hasData) { final clusters = snapshot.data!; - final List keys = clusters.keys.toList(); + final List keys = clusters.keys.toList(); // Sort the clusters by the number of files in each cluster, largest first keys.sort( (b, a) => clusters[a]!.length.compareTo(clusters[b]!.length), @@ -45,7 +46,7 @@ class _PersonClustersPageState extends State { return ListView.builder( itemCount: keys.length, itemBuilder: (context, index) { - final int clusterID = keys[index]; + final String clusterID = keys[index]; final List files = clusters[clusterID]!; return InkWell( onTap: () { @@ -54,7 +55,7 @@ class _PersonClustersPageState extends State { builder: (context) => ClusterPage( files, personID: widget.person, - clusterID: index, + clusterID: clusterID, showNamingBanner: false, ), ), @@ -91,7 +92,8 @@ class _PersonClustersPageState extends State { ), // Add some spacing between the thumbnail and the text Expanded( child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 8.0), + padding: + const EdgeInsets.symmetric(horizontal: 8.0), child: Row( mainAxisAlignment: MainAxisAlignment.spaceBetween, children: [ @@ -103,14 +105,16 @@ class _PersonClustersPageState extends State { ? GestureDetector( onTap: () async { try { - await PersonService.instance.removeClusterToPerson( + await PersonService.instance + .removeClusterToPerson( personID: widget.person.remoteID, clusterID: clusterID, ); _logger.info( "Removed cluster $clusterID from person ${widget.person.remoteID}", ); - Bus.instance.fire(PeopleChangedEvent()); + Bus.instance + .fire(PeopleChangedEvent()); setState(() {}); } catch (e) { _logger.severe( diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart index c63a8cd890..dda445b862 100644 --- a/mobile/lib/ui/viewer/search/result/person_face_widget.dart +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -17,7 +17,7 @@ import "package:pool/pool.dart"; class PersonFaceWidget extends StatelessWidget { final EnteFile file; final String? personId; - final int? clusterID; + final String? clusterID; final bool useFullFile; final bool thumbnailFallback; final Uint8List? faceCrop; diff --git a/mobile/lib/ui/viewer/search_tab/people_section.dart b/mobile/lib/ui/viewer/search_tab/people_section.dart index 13e2f8a813..b3ce01045f 100644 --- a/mobile/lib/ui/viewer/search_tab/people_section.dart +++ b/mobile/lib/ui/viewer/search_tab/people_section.dart @@ -273,7 +273,7 @@ class SearchExample extends StatelessWidget { onTap: () async { final result = await showAssignPersonAction( context, - clusterID: int.parse(searchResult.name()), + clusterID: searchResult.name(), ); if (result != null && result is (PersonEntity, EnteFile)) { From df46eeab10d53128c3d9d0d82ca94c9d7bf7e0e7 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 13 Aug 2024 23:12:29 +0530 Subject: [PATCH 04/17] refactor --- mobile/lib/services/entity_service.dart | 9 ++++----- mobile/lib/services/location_service.dart | 4 ++-- .../face_ml/person/person_service.dart | 16 ++++++---------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/mobile/lib/services/entity_service.dart b/mobile/lib/services/entity_service.dart index 6ffe87358b..3f00d264fd 100644 --- a/mobile/lib/services/entity_service.dart +++ b/mobile/lib/services/entity_service.dart @@ -56,14 +56,13 @@ class EntityService { Future addOrUpdate( EntityType type, - String plainText, { + Map jsonMap, { String? id, }) async { + final String plainText = jsonEncode(jsonMap); final key = await getOrCreateEntityKey(type); - final encryptedKeyData = await CryptoUtil.encryptChaCha( - utf8.encode(plainText), - key, - ); + final encryptedKeyData = + await CryptoUtil.encryptChaCha(utf8.encode(plainText), key); final String encryptedData = CryptoUtil.bin2base64(encryptedKeyData.encryptedData!); final String header = CryptoUtil.bin2base64(encryptedKeyData.header!); diff --git a/mobile/lib/services/location_service.dart b/mobile/lib/services/location_service.dart index f31a2803f6..f929dca525 100644 --- a/mobile/lib/services/location_service.dart +++ b/mobile/lib/services/location_service.dart @@ -90,7 +90,7 @@ class LocationService { centerPoint: centerPoint, ); await EntityService.instance - .addOrUpdate(EntityType.location, json.encode(locationTag.toJson())); + .addOrUpdate(EntityType.location, locationTag.toJson()); Bus.instance.fire(LocationTagUpdatedEvent(LocTagEventType.add)); } catch (e, s) { _logger.severe("Failed to add location tag", e, s); @@ -179,7 +179,7 @@ class LocationService { await EntityService.instance.addOrUpdate( EntityType.location, - json.encode(updatedLoationTag.toJson()), + updatedLoationTag.toJson(), id: locationTagEntity.id, ); Bus.instance.fire( diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 187e51faf6..9a94b28160 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -95,11 +95,7 @@ class PersonService { ) .toList(); entityService - .addOrUpdate( - EntityType.person, - json.encode(personData.toJson()), - id: personID, - ) + .addOrUpdate(EntityType.person, personData.toJson(), id: personID) .ignore(); personData.logStats(); } @@ -168,7 +164,7 @@ class PersonService { ); final result = await entityService.addOrUpdate( EntityType.person, - json.encode(data.toJson()), + data.toJson(), ); await faceMLDataDB.assignClusterToPerson( personID: result.id, @@ -186,7 +182,7 @@ class PersonService { personData.assigned!.removeWhere((element) => element.id != clusterID); await entityService.addOrUpdate( EntityType.person, - json.encode(personData.toJson()), + personData.toJson(), id: personID, ); await faceMLDataDB.removeClusterToPerson( @@ -221,7 +217,7 @@ class PersonService { await entityService.addOrUpdate( EntityType.person, - json.encode(personData.toJson()), + personData.toJson(), id: person.remoteID, ); personData.logStats(); @@ -237,7 +233,7 @@ class PersonService { PersonEntity(personID, PersonData(name: entity.data.name)); await entityService.addOrUpdate( EntityType.person, - json.encode(justName.data.toJson()), + justName.data.toJson(), id: personID, ); await faceMLDataDB.removePerson(personID); @@ -312,7 +308,7 @@ class PersonService { Future _updatePerson(PersonEntity updatePerson) async { await entityService.addOrUpdate( EntityType.person, - json.encode(updatePerson.data.toJson()), + updatePerson.data.toJson(), id: updatePerson.remoteID, ); updatePerson.data.logStats(); From bc43d3b461da1ff661f5abe383891f79f15de4e8 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 10:52:51 +0530 Subject: [PATCH 05/17] Update avatar faceID key --- mobile/lib/face/model/person.dart | 12 ++++++------ .../ui/viewer/search/result/person_face_widget.dart | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mobile/lib/face/model/person.dart b/mobile/lib/face/model/person.dart index 7536d3b0a6..49754be925 100644 --- a/mobile/lib/face/model/person.dart +++ b/mobile/lib/face/model/person.dart @@ -49,12 +49,12 @@ class ClusterInfo { class PersonData { final String name; final bool isHidden; - String? avatarFaceId; + String? avatarFaceID; List? assigned = List.empty(); List? rejected = List.empty(); final String? birthDate; - bool hasAvatar() => avatarFaceId != null; + bool hasAvatar() => avatarFaceID != null; bool get isIgnored => (name.isEmpty || name == '(hidden)' || name == '(ignored)'); @@ -63,7 +63,7 @@ class PersonData { required this.name, this.assigned, this.rejected, - this.avatarFaceId, + this.avatarFaceID, this.isHidden = false, this.birthDate, }); @@ -79,7 +79,7 @@ class PersonData { return PersonData( name: name ?? this.name, assigned: assigned ?? this.assigned, - avatarFaceId: avatarFaceId ?? this.avatarFaceId, + avatarFaceID: avatarFaceId ?? this.avatarFaceID, isHidden: isHidden ?? this.isHidden, birthDate: birthDate ?? this.birthDate, ); @@ -109,7 +109,7 @@ class PersonData { 'name': name, 'assigned': assigned?.map((e) => e.toJson()).toList(), 'rejected': rejected?.map((e) => e.toJson()).toList(), - 'avatarFaceId': avatarFaceId, + 'avatarFaceID': avatarFaceID, 'isHidden': isHidden, 'birthDate': birthDate, }; @@ -131,7 +131,7 @@ class PersonData { name: json['name'] as String, assigned: assigned, rejected: rejected, - avatarFaceId: json['avatarFaceId'] as String?, + avatarFaceID: json['avatarFaceId'] as String?, isHidden: json['isHidden'] as bool? ?? false, birthDate: json['birthDate'] as String?, ); diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart index dda445b862..eea17134b2 100644 --- a/mobile/lib/ui/viewer/search/result/person_face_widget.dart +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -83,7 +83,7 @@ class PersonFaceWidget extends StatelessWidget { final PersonEntity? personEntity = await PersonService.instance.getPerson(personId!); if (personEntity != null) { - personAvatarFaceID = personEntity.data.avatarFaceId; + personAvatarFaceID = personEntity.data.avatarFaceID; } } return await FaceMLDataDB.instance.getCoverFaceForPerson( From 9e22cfa9724421d2dae3f3676f2d14a1ce09a152 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:25:21 +0530 Subject: [PATCH 06/17] Add method to async decrypt json data --- mobile/lib/utils/gzip.dart | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mobile/lib/utils/gzip.dart b/mobile/lib/utils/gzip.dart index d1680d43e7..2cc86848ee 100644 --- a/mobile/lib/utils/gzip.dart +++ b/mobile/lib/utils/gzip.dart @@ -28,6 +28,25 @@ Uint8List _gzipUInt8List(Uint8List data) { return Uint8List.fromList(compressedData); } +Future> decryptAndUnzipJson( + Uint8List key, { + required String encryptedData, + required String header, +}) async { + final Computer computer = Computer.shared(); + final response = + await computer.compute, Map>( + _decryptAndUnzipJsonSync, + param: { + "key": key, + "encryptedData": encryptedData, + "header": header, + }, + taskName: "decryptAndUnzipJson", + ); + return response; +} + Map decryptAndUnzipJsonSync( Uint8List key, { required String encryptedData, @@ -82,3 +101,10 @@ ChaChaEncryptionResult _gzipAndEncryptJsonSync( ) { return gzipAndEncryptJsonSync(args["jsonData"], args["key"]); } + +Map _decryptAndUnzipJsonSync( + Map args, +) { + return decryptAndUnzipJsonSync(args["key"], + encryptedData: args["encryptedData"], header: args["header"]); +} From 9d9b4f00ed7f42fd4edf08160d184ad62b87e348 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:25:36 +0530 Subject: [PATCH 07/17] Fix nano id alphabets --- mobile/lib/models/nanoids/cluster_id.dart | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mobile/lib/models/nanoids/cluster_id.dart b/mobile/lib/models/nanoids/cluster_id.dart index 278e0bb1a2..ee2d14013e 100644 --- a/mobile/lib/models/nanoids/cluster_id.dart +++ b/mobile/lib/models/nanoids/cluster_id.dart @@ -1,13 +1,13 @@ import "package:flutter/foundation.dart"; import 'package:nanoid/nanoid.dart'; -const alphaphet = +const enteWhiteListedAlphabet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'; const clusterIDLength = 22; class ClusterID { static String generate() { - return "cluster_${customAlphabet(urlAlphabet, clusterIDLength)}"; + return "cluster_${customAlphabet(enteWhiteListedAlphabet, clusterIDLength)}"; } // Validation method From d7ffb3c7e0a8bf64f7fb0eee212f26ccb95879f4 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:25:54 +0530 Subject: [PATCH 08/17] Update endpoint --- mobile/lib/services/filedata/filedata_service.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/services/filedata/filedata_service.dart b/mobile/lib/services/filedata/filedata_service.dart index 1d7d184050..e0e2d6c2db 100644 --- a/mobile/lib/services/filedata/filedata_service.dart +++ b/mobile/lib/services/filedata/filedata_service.dart @@ -32,7 +32,7 @@ class FileDataService { try { final _ = await _dio.put( - "/files/data/", + "/files/data", data: { "fileID": file.uploadedFileID!, "type": data.type.toJson(), From 9f96ef8d83df66dad64c53d8a04e65a5708291b0 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:27:36 +0530 Subject: [PATCH 09/17] [mob] Switch to person_v2 where data is gzipped --- mobile/lib/models/api/entity/type.dart | 12 ++++++ mobile/lib/services/entity_service.dart | 42 +++++++++++++------ .../face_ml/person/person_service.dart | 20 ++++----- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/mobile/lib/models/api/entity/type.dart b/mobile/lib/models/api/entity/type.dart index 88e60d62f3..431ea4e57b 100644 --- a/mobile/lib/models/api/entity/type.dart +++ b/mobile/lib/models/api/entity/type.dart @@ -3,6 +3,7 @@ import "package:flutter/foundation.dart"; enum EntityType { location, person, + personV2, unknown, } @@ -12,18 +13,29 @@ EntityType typeFromString(String type) { return EntityType.location; case "person": return EntityType.location; + case "person_v2": + return EntityType.personV2; } debugPrint("unexpected collection type $type"); return EntityType.unknown; } extension EntityTypeExtn on EntityType { + bool isZipped() { + if (this == EntityType.location || this == EntityType.person) { + return false; + } + return true; + } + String typeToString() { switch (this) { case EntityType.location: return "location"; case EntityType.person: return "person"; + case EntityType.personV2: + return "person_v2"; case EntityType.unknown: return "unknown"; } diff --git a/mobile/lib/services/entity_service.dart b/mobile/lib/services/entity_service.dart index 3f00d264fd..e979d4fd9e 100644 --- a/mobile/lib/services/entity_service.dart +++ b/mobile/lib/services/entity_service.dart @@ -14,6 +14,7 @@ import "package:photos/models/api/entity/key.dart"; import "package:photos/models/api/entity/type.dart"; import "package:photos/models/local_entity_data.dart"; import "package:photos/utils/crypto_util.dart"; +import "package:photos/utils/gzip.dart"; import 'package:shared_preferences/shared_preferences.dart'; class EntityService { @@ -61,11 +62,18 @@ class EntityService { }) async { final String plainText = jsonEncode(jsonMap); final key = await getOrCreateEntityKey(type); - final encryptedKeyData = - await CryptoUtil.encryptChaCha(utf8.encode(plainText), key); - final String encryptedData = - CryptoUtil.bin2base64(encryptedKeyData.encryptedData!); - final String header = CryptoUtil.bin2base64(encryptedKeyData.header!); + late String encryptedData, header; + if (type.isZipped()) { + final ChaChaEncryptionResult result = + await gzipAndEncryptJson(jsonMap, key); + encryptedData = result.encData; + header = result.header; + } else { + final encryptedKeyData = + await CryptoUtil.encryptChaCha(utf8.encode(plainText), key); + encryptedData = CryptoUtil.bin2base64(encryptedKeyData.encryptedData!); + header = CryptoUtil.bin2base64(encryptedKeyData.header!); + } debugPrint( " ${id == null ? 'Adding' : 'Updating'} entity of type: " + type.typeToString(), @@ -93,7 +101,7 @@ class EntityService { Future syncEntities() async { try { await _remoteToLocalSync(EntityType.location); - await _remoteToLocalSync(EntityType.person); + await _remoteToLocalSync(EntityType.personV2); } catch (e) { _logger.severe("Failed to sync entities", e); } @@ -126,12 +134,22 @@ class EntityService { final List entities = []; for (EntityData e in result) { try { - final decryptedValue = await CryptoUtil.decryptChaCha( - CryptoUtil.base642bin(e.encryptedData!), - entityKey, - CryptoUtil.base642bin(e.header!), - ); - final String plainText = utf8.decode(decryptedValue); + late String plainText; + if (type.isZipped()) { + final jsonMap = await decryptAndUnzipJson( + entityKey, + encryptedData: e.encryptedData!, + header: e.header!, + ); + plainText = jsonEncode(jsonMap); + } else { + final Uint8List decryptedValue = await CryptoUtil.decryptChaCha( + CryptoUtil.base642bin(e.encryptedData!), + entityKey, + CryptoUtil.base642bin(e.header!), + ); + plainText = utf8.decode(decryptedValue); + } entities.add( LocalEntityData( id: e.id, diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 9a94b28160..5a5f4b23a6 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -37,7 +37,7 @@ class PersonService { } Future> getPersons() async { - final entities = await entityService.getEntities(EntityType.person); + final entities = await entityService.getEntities(EntityType.personV2); return entities .map( (e) => PersonEntity(e.id, PersonData.fromJson(json.decode(e.data))), @@ -46,7 +46,7 @@ class PersonService { } Future getPerson(String id) { - return entityService.getEntity(EntityType.person, id).then((e) { + return entityService.getEntity(EntityType.personV2, id).then((e) { if (e == null) { return null; } @@ -55,7 +55,7 @@ class PersonService { } Future> getPersonsMap() async { - final entities = await entityService.getEntities(EntityType.person); + final entities = await entityService.getEntities(EntityType.personV2); final Map map = {}; for (var e in entities) { final person = @@ -95,7 +95,7 @@ class PersonService { ) .toList(); entityService - .addOrUpdate(EntityType.person, personData.toJson(), id: personID) + .addOrUpdate(EntityType.personV2, personData.toJson(), id: personID) .ignore(); personData.logStats(); } @@ -163,7 +163,7 @@ class PersonService { isHidden: isHidden, ); final result = await entityService.addOrUpdate( - EntityType.person, + EntityType.personV2, data.toJson(), ); await faceMLDataDB.assignClusterToPerson( @@ -181,7 +181,7 @@ class PersonService { final personData = person.data; personData.assigned!.removeWhere((element) => element.id != clusterID); await entityService.addOrUpdate( - EntityType.person, + EntityType.personV2, personData.toJson(), id: personID, ); @@ -216,7 +216,7 @@ class PersonService { } await entityService.addOrUpdate( - EntityType.person, + EntityType.personV2, personData.toJson(), id: person.remoteID, ); @@ -232,7 +232,7 @@ class PersonService { final PersonEntity justName = PersonEntity(personID, PersonData(name: entity.data.name)); await entityService.addOrUpdate( - EntityType.person, + EntityType.personV2, justName.data.toJson(), id: personID, ); @@ -249,7 +249,7 @@ class PersonService { Future fetchRemoteClusterFeedback() async { await entityService.syncEntities(); - final entities = await entityService.getEntities(EntityType.person); + final entities = await entityService.getEntities(EntityType.personV2); entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt)); final Map faceIdToClusterID = {}; final Map clusterToPersonID = {}; @@ -307,7 +307,7 @@ class PersonService { Future _updatePerson(PersonEntity updatePerson) async { await entityService.addOrUpdate( - EntityType.person, + EntityType.personV2, updatePerson.data.toJson(), id: updatePerson.remoteID, ); From 3a1025a2b9d59c4dbcba814226bbbb3c012c6f9f Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:07:01 +0530 Subject: [PATCH 10/17] Lint fix --- mobile/lib/utils/gzip.dart | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mobile/lib/utils/gzip.dart b/mobile/lib/utils/gzip.dart index 2cc86848ee..8de71f49df 100644 --- a/mobile/lib/utils/gzip.dart +++ b/mobile/lib/utils/gzip.dart @@ -105,6 +105,9 @@ ChaChaEncryptionResult _gzipAndEncryptJsonSync( Map _decryptAndUnzipJsonSync( Map args, ) { - return decryptAndUnzipJsonSync(args["key"], - encryptedData: args["encryptedData"], header: args["header"]); + return decryptAndUnzipJsonSync( + args["key"], + encryptedData: args["encryptedData"], + header: args["header"], + ); } From 5625071fe86f3a595c019de13d446ac7a104b5f3 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:24:39 +0530 Subject: [PATCH 11/17] Rearrange --- mobile/lib/{face => db/ml}/db.dart | 0 mobile/lib/{face => db/ml}/db_fields.dart | 0 mobile/lib/{face => db/ml}/db_model_mappers.dart | 0 mobile/lib/db/{ => ml}/embeddings_db.dart | 0 mobile/lib/models/{embedding.dart => ml/clip.dart} | 0 mobile/lib/{face/model => models/ml/face}/box.dart | 0 mobile/lib/{face/model => models/ml/face}/detection.dart | 0 mobile/lib/{face/model => models/ml/face}/dimension.dart | 0 mobile/lib/{face/model => models/ml/face}/face.dart | 0 mobile/lib/{face/model => models/ml/face}/landmark.dart | 0 mobile/lib/{face/model => models/ml/face}/person.dart | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename mobile/lib/{face => db/ml}/db.dart (100%) rename mobile/lib/{face => db/ml}/db_fields.dart (100%) rename mobile/lib/{face => db/ml}/db_model_mappers.dart (100%) rename mobile/lib/db/{ => ml}/embeddings_db.dart (100%) rename mobile/lib/models/{embedding.dart => ml/clip.dart} (100%) rename mobile/lib/{face/model => models/ml/face}/box.dart (100%) rename mobile/lib/{face/model => models/ml/face}/detection.dart (100%) rename mobile/lib/{face/model => models/ml/face}/dimension.dart (100%) rename mobile/lib/{face/model => models/ml/face}/face.dart (100%) rename mobile/lib/{face/model => models/ml/face}/landmark.dart (100%) rename mobile/lib/{face/model => models/ml/face}/person.dart (100%) diff --git a/mobile/lib/face/db.dart b/mobile/lib/db/ml/db.dart similarity index 100% rename from mobile/lib/face/db.dart rename to mobile/lib/db/ml/db.dart diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/db/ml/db_fields.dart similarity index 100% rename from mobile/lib/face/db_fields.dart rename to mobile/lib/db/ml/db_fields.dart diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/db/ml/db_model_mappers.dart similarity index 100% rename from mobile/lib/face/db_model_mappers.dart rename to mobile/lib/db/ml/db_model_mappers.dart diff --git a/mobile/lib/db/embeddings_db.dart b/mobile/lib/db/ml/embeddings_db.dart similarity index 100% rename from mobile/lib/db/embeddings_db.dart rename to mobile/lib/db/ml/embeddings_db.dart diff --git a/mobile/lib/models/embedding.dart b/mobile/lib/models/ml/clip.dart similarity index 100% rename from mobile/lib/models/embedding.dart rename to mobile/lib/models/ml/clip.dart diff --git a/mobile/lib/face/model/box.dart b/mobile/lib/models/ml/face/box.dart similarity index 100% rename from mobile/lib/face/model/box.dart rename to mobile/lib/models/ml/face/box.dart diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/models/ml/face/detection.dart similarity index 100% rename from mobile/lib/face/model/detection.dart rename to mobile/lib/models/ml/face/detection.dart diff --git a/mobile/lib/face/model/dimension.dart b/mobile/lib/models/ml/face/dimension.dart similarity index 100% rename from mobile/lib/face/model/dimension.dart rename to mobile/lib/models/ml/face/dimension.dart diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/models/ml/face/face.dart similarity index 100% rename from mobile/lib/face/model/face.dart rename to mobile/lib/models/ml/face/face.dart diff --git a/mobile/lib/face/model/landmark.dart b/mobile/lib/models/ml/face/landmark.dart similarity index 100% rename from mobile/lib/face/model/landmark.dart rename to mobile/lib/models/ml/face/landmark.dart diff --git a/mobile/lib/face/model/person.dart b/mobile/lib/models/ml/face/person.dart similarity index 100% rename from mobile/lib/face/model/person.dart rename to mobile/lib/models/ml/face/person.dart From 810cf6f8853d63656008775ecb06dd9eb66fb89d Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:21:52 +0530 Subject: [PATCH 12/17] [mob] Use single db for ml data --- mobile/lib/core/configuration.dart | 7 +- mobile/lib/db/ml/db.dart | 61 +++++++--- mobile/lib/db/ml/db_fields.dart | 18 ++- mobile/lib/db/ml/db_model_mappers.dart | 10 +- mobile/lib/db/ml/embeddings_db.dart | 114 ++++-------------- mobile/lib/main.dart | 2 +- mobile/lib/models/ml/face/detection.dart | 4 +- mobile/lib/models/ml/face/face.dart | 2 +- .../services/filedata/model/file_data.dart | 2 +- .../face_ml/face_detection/detection.dart | 2 +- .../face_detection_service.dart | 2 +- .../face_ml/face_recognition_service.dart | 10 +- .../face_ml/feedback/cluster_feedback.dart | 6 +- .../face_ml/person/person_service.dart | 4 +- .../machine_learning/ml_computer.dart | 2 +- .../services/machine_learning/ml_result.dart | 2 +- .../services/machine_learning/ml_service.dart | 10 +- mobile/lib/services/search_service.dart | 4 +- .../bottom_action_bar_widget.dart | 2 +- .../debug/ml_debug_section_widget.dart | 4 +- .../file_selection_actions_widget.dart | 2 +- .../actions/file_selection_overlay_bar.dart | 2 +- .../ui/viewer/file_details/face_widget.dart | 6 +- .../file_details/faces_item_widget.dart | 8 +- .../people/add_person_action_sheet.dart | 4 +- .../lib/ui/viewer/people/cluster_app_bar.dart | 4 +- mobile/lib/ui/viewer/people/cluster_page.dart | 2 +- .../lib/ui/viewer/people/people_app_bar.dart | 23 ++-- mobile/lib/ui/viewer/people/people_page.dart | 2 +- .../people/person_cluster_suggestion.dart | 4 +- .../viewer/people/person_clusters_page.dart | 2 +- .../lib/ui/viewer/people/person_row_item.dart | 2 +- .../search/result/person_face_widget.dart | 6 +- .../ui/viewer/search_tab/people_section.dart | 2 +- mobile/lib/utils/face/face_box_crop.dart | 2 +- mobile/lib/utils/face/face_util.dart | 2 +- mobile/lib/utils/image_ml_util.dart | 4 +- mobile/lib/utils/ml_util.dart | 10 +- 38 files changed, 166 insertions(+), 189 deletions(-) diff --git a/mobile/lib/core/configuration.dart b/mobile/lib/core/configuration.dart index a1f7d51daf..9fe284c583 100644 --- a/mobile/lib/core/configuration.dart +++ b/mobile/lib/core/configuration.dart @@ -11,15 +11,14 @@ import 'package:photos/core/constants.dart'; import 'package:photos/core/error-reporting/super_logging.dart'; import 'package:photos/core/event_bus.dart'; import 'package:photos/db/collections_db.dart'; -import "package:photos/db/embeddings_db.dart"; import 'package:photos/db/files_db.dart'; import 'package:photos/db/memories_db.dart'; +import "package:photos/db/ml/db.dart"; import 'package:photos/db/trash_db.dart'; import 'package:photos/db/upload_locks_db.dart'; import "package:photos/events/endpoint_updated_event.dart"; import 'package:photos/events/signed_in_event.dart'; import 'package:photos/events/user_logged_out_event.dart'; -import "package:photos/face/db.dart"; import 'package:photos/models/key_attributes.dart'; import 'package:photos/models/key_gen_result.dart'; import 'package:photos/models/private_key_attributes.dart'; @@ -28,7 +27,6 @@ import 'package:photos/services/collections_service.dart'; import 'package:photos/services/favorites_service.dart'; import "package:photos/services/home_widget_service.dart"; import 'package:photos/services/ignored_files_service.dart'; -import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import 'package:photos/services/memories_service.dart'; import 'package:photos/services/search_service.dart'; import 'package:photos/services/sync_service.dart'; @@ -205,9 +203,6 @@ class Configuration { _cachedToken = null; _secretKey = null; await FilesDB.instance.clearTable(); - SemanticSearchService.instance.hasInitialized - ? await EmbeddingsDB.instance.clearTable() - : null; await CollectionsDB.instance.clearTable(); await MemoriesDB.instance.clearTable(); await FaceMLDataDB.instance.clearTable(); diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index cdc014f86e..11c9cf26d4 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -6,10 +6,10 @@ import "package:flutter/foundation.dart"; import 'package:logging/logging.dart'; import 'package:path/path.dart' show join; import 'package:path_provider/path_provider.dart'; +import 'package:photos/db/ml/db_fields.dart'; +import "package:photos/db/ml/db_model_mappers.dart"; import "package:photos/extensions/stop_watch.dart"; -import 'package:photos/face/db_fields.dart'; -import "package:photos/face/db_model_mappers.dart"; -import "package:photos/face/model/face.dart"; +import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; @@ -28,7 +28,7 @@ import 'package:sqlite_async/sqlite_async.dart'; class FaceMLDataDB { static final Logger _logger = Logger("FaceMLDataDB"); - static const _databaseName = "ente.face_ml_db_v2.db"; + static const _databaseName = "ente.face_ml_db_v3.db"; // static const _databaseVersion = 1; FaceMLDataDB._privateConstructor(); @@ -42,6 +42,7 @@ class FaceMLDataDB { createClusterSummaryTable, createNotPersonFeedbackTable, fcClusterIDIndex, + createClipEmbeddingsTable, ]; // only have a single app-wide reference to the database @@ -111,9 +112,9 @@ class FaceMLDataDB { const String sql = ''' INSERT INTO $facesTable ( - $fileIDColumn, $faceIDColumn, $faceDetectionColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways, $imageHeight, $imageWidth, $mlVersionColumn + $fileIDColumn, $faceIDColumn, $faceDetectionColumn, $embeddingColumn, $faceScore, $faceBlur, $isSideways, $imageHeight, $imageWidth, $mlVersionColumn ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT($fileIDColumn, $faceIDColumn) DO UPDATE SET $faceIDColumn = excluded.$faceIDColumn, $faceDetectionColumn = excluded.$faceDetectionColumn, $faceEmbeddingBlob = excluded.$faceEmbeddingBlob, $faceScore = excluded.$faceScore, $faceBlur = excluded.$faceBlur, $isSideways = excluded.$isSideways, $imageHeight = excluded.$imageHeight, $imageWidth = excluded.$imageWidth, $mlVersionColumn = excluded.$mlVersionColumn + ON CONFLICT($fileIDColumn, $faceIDColumn) DO UPDATE SET $faceIDColumn = excluded.$faceIDColumn, $faceDetectionColumn = excluded.$faceDetectionColumn, $embeddingColumn = excluded.$embeddingColumn, $faceScore = excluded.$faceScore, $faceBlur = excluded.$faceBlur, $isSideways = excluded.$isSideways, $imageHeight = excluded.$imageHeight, $imageWidth = excluded.$imageWidth, $mlVersionColumn = excluded.$mlVersionColumn '''; final parameterSets = batch.map((face) { final map = mapRemoteToFaceDB(face); @@ -121,7 +122,7 @@ class FaceMLDataDB { map[fileIDColumn], map[faceIDColumn], map[faceDetectionColumn], - map[faceEmbeddingBlob], + map[embeddingColumn], map[faceScore], map[faceBlur], map[isSideways], @@ -232,6 +233,7 @@ class FaceMLDataDB { await db.execute(deleteClusterPersonTable); await db.execute(deleteClusterSummaryTable); await db.execute(deleteNotPersonFeedbackTable); + await db.execute(deleteClipEmbeddingsTable); } Future> getFaceEmbeddingsForCluster( @@ -240,10 +242,10 @@ class FaceMLDataDB { }) async { final db = await instance.asyncDB; final List> maps = await db.getAll( - 'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}', + 'SELECT $embeddingColumn FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}', [clusterID], ); - return maps.map((e) => e[faceEmbeddingBlob] as Uint8List); + return maps.map((e) => e[embeddingColumn] as Uint8List); } Future>> getFaceEmbeddingsForClusters( @@ -254,7 +256,7 @@ class FaceMLDataDB { final Map> result = {}; final selectQuery = ''' - SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob + SELECT fc.$fcClusterID, fe.$embeddingColumn FROM $faceClustersTable fc INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn WHERE fc.$fcClusterID IN (${clusterIDs.join(',')}) @@ -265,7 +267,7 @@ class FaceMLDataDB { for (final map in maps) { final clusterID = map[fcClusterID] as String; - final faceEmbedding = map[faceEmbeddingBlob] as Uint8List; + final faceEmbedding = map[embeddingColumn] as Uint8List; result.putIfAbsent(clusterID, () => []).add(faceEmbedding); } @@ -300,10 +302,25 @@ class FaceMLDataDB { ); final clusterIDs = clusterRows.map((e) => e[clusterIDColumn] as String).toList(); + // final List> faceMaps = await db.getAll( + // 'SELECT * FROM $facesTable where ' + // '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' + // 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC', + // ); + final List> faceMaps = await db.getAll( - 'SELECT * FROM $facesTable where ' - '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' - 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC', + ''' + SELECT * FROM $facesTable + WHERE $faceIDColumn IN ( + SELECT $fcFaceId + FROM $faceClustersTable + WHERE $fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')}) + ) + AND $fileIDColumn IN (${List.filled(fileId.length, '?').join(',')}) + AND $faceScore > ? + ORDER BY $faceScore DESC + ''', + [...clusterIDs, ...fileId, kMinimumQualityFaceScore], ); if (faceMaps.isNotEmpty) { if (avatarFileId != null) { @@ -367,6 +384,14 @@ class FaceMLDataDB { final List> maps = await db.getAll( 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})', ); + final List> maps = await db.query( + faceClustersTable, + columns: [fcClusterID, fcFaceId], + where: + '$fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')})', + whereArgs: clusterIDs, + ); + for (final map in maps) { final clusterID = map[fcClusterID] as String; final faceID = map[fcFaceId] as String; @@ -541,7 +566,7 @@ class FaceMLDataDB { while (true) { // Query a batch of rows final List> maps = await db.getAll( - 'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable' + 'SELECT $faceIDColumn, $embeddingColumn, $faceScore, $faceBlur, $isSideways FROM $facesTable' ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity' ' ORDER BY $faceIDColumn' ' DESC LIMIT $batchSize OFFSET $offset', @@ -560,7 +585,7 @@ class FaceMLDataDB { final faceInfo = FaceDbInfoForClustering( faceID: faceID, clusterId: faceIdToClusterId[faceID], - embeddingBytes: map[faceEmbeddingBlob] as Uint8List, + embeddingBytes: map[embeddingColumn] as Uint8List, faceScore: map[faceScore] as double, blurValue: map[faceBlur] as double, isSideways: (map[isSideways] as int) == 1, @@ -594,7 +619,7 @@ class FaceMLDataDB { while (true) { // Query a batch of rows final String query = ''' - SELECT $faceIDColumn, $faceEmbeddingBlob + SELECT $faceIDColumn, $embeddingColumn FROM $facesTable WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")}) ORDER BY $faceIDColumn DESC @@ -607,7 +632,7 @@ class FaceMLDataDB { } for (final map in maps) { final faceID = map[faceIDColumn] as String; - result[faceID] = map[faceEmbeddingBlob] as Uint8List; + result[faceID] = map[embeddingColumn] as Uint8List; } if (result.length > 10000) { break; diff --git a/mobile/lib/db/ml/db_fields.dart b/mobile/lib/db/ml/db_fields.dart index c1de4780bc..f05953855e 100644 --- a/mobile/lib/db/ml/db_fields.dart +++ b/mobile/lib/db/ml/db_fields.dart @@ -5,7 +5,7 @@ const facesTable = 'faces'; const fileIDColumn = 'file_id'; const faceIDColumn = 'face_id'; const faceDetectionColumn = 'detection'; -const faceEmbeddingBlob = 'eBlob'; +const embeddingColumn = 'embedding'; const faceScore = 'score'; const faceBlur = 'blur'; const isSideways = 'is_sideways'; @@ -18,7 +18,7 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( $fileIDColumn INTEGER NOT NULL, $faceIDColumn TEXT NOT NULL UNIQUE, $faceDetectionColumn TEXT NOT NULL, - $faceEmbeddingBlob BLOB NOT NULL, + $embeddingColumn BLOB NOT NULL, $faceScore REAL NOT NULL, $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, $isSideways INTEGER NOT NULL DEFAULT 0, @@ -95,3 +95,17 @@ CREATE TABLE IF NOT EXISTS $notPersonFeedback ( '''; const deleteNotPersonFeedbackTable = 'DELETE FROM $notPersonFeedback'; // End Clusters Table Fields & Schema Queries + +// ## CLIP EMBEDDINGS TABLE +const clipTable = 'clip'; + +const createClipEmbeddingsTable = ''' +CREATE TABLE IF NOT EXISTS $clipTable ( + $fileIDColumn INTEGER NOT NULL, + $embeddingColumn BLOB NOT NULL, + $mlVersionColumn INTEGER NOT NULL, + PRIMARY KEY ($fileIDColumn) + ); +'''; + +const deleteClipEmbeddingsTable = 'DELETE FROM $clipTable'; diff --git a/mobile/lib/db/ml/db_model_mappers.dart b/mobile/lib/db/ml/db_model_mappers.dart index 85f7d3ce17..7e84f1ced8 100644 --- a/mobile/lib/db/ml/db_model_mappers.dart +++ b/mobile/lib/db/ml/db_model_mappers.dart @@ -1,9 +1,9 @@ import "dart:convert"; -import 'package:photos/face/db_fields.dart'; -import "package:photos/face/model/detection.dart"; -import "package:photos/face/model/face.dart"; +import 'package:photos/db/ml/db_fields.dart'; import "package:photos/generated/protos/ente/common/vector.pb.dart"; +import "package:photos/models/ml/face/detection.dart"; +import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/ml_versions.dart"; Map mapRemoteToFaceDB(Face face) { @@ -11,7 +11,7 @@ Map mapRemoteToFaceDB(Face face) { faceIDColumn: face.faceID, fileIDColumn: face.fileID, faceDetectionColumn: json.encode(face.detection.toJson()), - faceEmbeddingBlob: EVector( + embeddingColumn: EVector( values: face.embedding, ).writeToBuffer(), faceScore: face.score, @@ -27,7 +27,7 @@ Face mapRowToFace(Map row) { return Face( row[faceIDColumn] as String, row[fileIDColumn] as int, - EVector.fromBuffer(row[faceEmbeddingBlob] as List).values, + EVector.fromBuffer(row[embeddingColumn] as List).values, row[faceScore] as double, Detection.fromJson(json.decode(row[faceDetectionColumn] as String)), row[faceBlur] as double, diff --git a/mobile/lib/db/ml/embeddings_db.dart b/mobile/lib/db/ml/embeddings_db.dart index 1cdd331c30..be537e3c67 100644 --- a/mobile/lib/db/ml/embeddings_db.dart +++ b/mobile/lib/db/ml/embeddings_db.dart @@ -1,135 +1,71 @@ import "dart:io"; import "dart:typed_data"; -import "package:path/path.dart"; -import 'package:path_provider/path_provider.dart'; import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/ml/db_fields.dart"; import "package:photos/events/embedding_updated_event.dart"; -import "package:photos/models/embedding.dart"; -import "package:sqlite_async/sqlite_async.dart"; - -class EmbeddingsDB { - EmbeddingsDB._privateConstructor(); - - static final EmbeddingsDB instance = EmbeddingsDB._privateConstructor(); +import "package:photos/models/ml/clip.dart"; +extension EmbeddingsDB on FaceMLDataDB { static const databaseName = "ente.embeddings.db"; - static const tableName = "clip_embedding"; - static const oldTableName = "embeddings"; - static const columnFileID = "file_id"; - static const columnEmbedding = "embedding"; - static const columnVersion = "version"; - - @Deprecated("") - static const columnUpdationTime = "updation_time"; - - static Future? _dbFuture; - - Future get _database async { - _dbFuture ??= _initDatabase(); - return _dbFuture!; - } - - Future init() async { - final dir = await getApplicationDocumentsDirectory(); - await _clearDeprecatedStores(dir); - } - - Future _initDatabase() async { - final Directory documentsDirectory = - await getApplicationDocumentsDirectory(); - final String path = join(documentsDirectory.path, databaseName); - final migrations = SqliteMigrations() - ..add( - SqliteMigration( - 1, - (tx) async { - // Avoid creating the old table - // await tx.execute( - // 'CREATE TABLE $oldTableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))', - // ); - }, - ), - ) - ..add( - SqliteMigration( - 2, - (tx) async { - // delete old table - await tx.execute('DROP TABLE IF EXISTS $oldTableName'); - await tx.execute( - 'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnVersion INTEGER, UNIQUE ($columnFileID))', - ); - }, - ), - ); - final database = SqliteDatabase(path: path); - await migrations.migrate(database); - return database; - } - - Future clearTable() async { - final db = await _database; - await db.execute('DELETE FROM $tableName'); - } Future> getAll() async { - final db = await _database; - final results = await db.getAll('SELECT * FROM $tableName'); + final db = await FaceMLDataDB.instance.asyncDB; + final results = await db.getAll('SELECT * FROM $clipTable'); return _convertToEmbeddings(results); } // Get indexed FileIDs - Future> getIndexedFileIds() async { - final db = await _database; + Future> clipIndexedFileWithVersion() async { + final db = await FaceMLDataDB.instance.asyncDB; final maps = await db - .getAll('SELECT $columnFileID , $columnVersion FROM $tableName'); + .getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable'); final Map result = {}; for (final map in maps) { - result[map[columnFileID] as int] = map[columnVersion] as int; + result[map[mlVersionColumn] as int] = map[mlVersionColumn] as int; } return result; } - // TODO: Add actual colomn for version and use here, similar to faces - Future getIndexedFileCount() async { - final db = await _database; + Future getClipIndexedFileCount() async { + final db = await FaceMLDataDB.instance.asyncDB; const String query = - 'SELECT COUNT(DISTINCT $columnFileID) as count FROM $tableName'; + 'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $clipTable'; final List> maps = await db.getAll(query); return maps.first['count'] as int; } Future put(ClipEmbedding embedding) async { - final db = await _database; + final db = await FaceMLDataDB.instance.asyncDB; await db.execute( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) VALUES (?, ?, ?)', + 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)', _getRowFromEmbedding(embedding), ); Bus.instance.fire(EmbeddingUpdatedEvent()); } Future putMany(List embeddings) async { - final db = await _database; + final db = await FaceMLDataDB.instance.asyncDB; final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); await db.executeBatch( - 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnVersion) values(?, ?, ?)', + 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)', inputs, ); Bus.instance.fire(EmbeddingUpdatedEvent()); } Future deleteEmbeddings(List fileIDs) async { - final db = await _database; + final db = await FaceMLDataDB.instance.asyncDB; await db.execute( - 'DELETE FROM $tableName WHERE $columnFileID IN (${fileIDs.join(", ")})', + 'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})', ); Bus.instance.fire(EmbeddingUpdatedEvent()); } - Future deleteAll() async { - final db = await _database; - await db.execute('DELETE FROM $tableName'); + Future deleteClipIndexes() async { + final db = await FaceMLDataDB.instance.asyncDB; + await db.execute('DELETE FROM $clipTable'); Bus.instance.fire(EmbeddingUpdatedEvent()); } @@ -144,9 +80,9 @@ class EmbeddingsDB { } ClipEmbedding _getEmbeddingFromRow(Map row) { - final fileID = row[columnFileID]; - final bytes = row[columnEmbedding] as Uint8List; - final version = row[columnVersion] as int; + final fileID = row[fileIDColumn] as int; + final bytes = row[embeddingColumn] as Uint8List; + final version = row[mlVersionColumn] as int; final list = Float32List.view(bytes.buffer); return ClipEmbedding(fileID: fileID, embedding: list, version: version); } diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index cfbb4a8367..1864fbbbac 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -18,9 +18,9 @@ import 'package:photos/core/constants.dart'; import 'package:photos/core/error-reporting/super_logging.dart'; import 'package:photos/core/errors.dart'; import 'package:photos/core/network/network.dart'; +import "package:photos/db/ml/db.dart"; import 'package:photos/db/upload_locks_db.dart'; import 'package:photos/ente_theme_data.dart'; -import "package:photos/face/db.dart"; import "package:photos/l10n/l10n.dart"; import "package:photos/service_locator.dart"; import 'package:photos/services/app_lifecycle_service.dart'; diff --git a/mobile/lib/models/ml/face/detection.dart b/mobile/lib/models/ml/face/detection.dart index f77fd61b8a..766e326fae 100644 --- a/mobile/lib/models/ml/face/detection.dart +++ b/mobile/lib/models/ml/face/detection.dart @@ -1,7 +1,7 @@ import "dart:math" show min, max; -import "package:photos/face/model/box.dart"; -import "package:photos/face/model/landmark.dart"; +import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/landmark.dart"; import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; /// Stores the face detection data, notably the bounding box and landmarks. diff --git a/mobile/lib/models/ml/face/face.dart b/mobile/lib/models/ml/face/face.dart index 99f07cb2dc..0ff5b86eb1 100644 --- a/mobile/lib/models/ml/face/face.dart +++ b/mobile/lib/models/ml/face/face.dart @@ -1,4 +1,4 @@ -import "package:photos/face/model/detection.dart"; +import "package:photos/models/ml/face/detection.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import "package:photos/services/machine_learning/ml_result.dart"; diff --git a/mobile/lib/services/filedata/model/file_data.dart b/mobile/lib/services/filedata/model/file_data.dart index 2c06cf601c..d73ee1d2f8 100644 --- a/mobile/lib/services/filedata/model/file_data.dart +++ b/mobile/lib/services/filedata/model/file_data.dart @@ -1,4 +1,4 @@ -import "package:photos/face/model/face.dart"; +import "package:photos/models/ml/face/face.dart"; const _faceKey = 'face'; const _clipKey = 'clip'; diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index 6dbb209dcd..de544314ef 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -1,6 +1,6 @@ import 'dart:math' show max, min; -import "package:photos/face/model/dimension.dart"; +import "package:photos/models/ml/face/dimension.dart"; enum FaceDirection { left, right, straight } diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart index 2aeb68d604..afe1a3cfae 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart @@ -6,7 +6,7 @@ import 'dart:ui' as ui show Image; import 'package:logging/logging.dart'; import "package:onnx_dart/onnx_dart.dart"; import 'package:onnxruntime/onnxruntime.dart'; -import "package:photos/face/model/dimension.dart"; +import "package:photos/models/ml/face/dimension.dart"; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart"; import "package:photos/services/machine_learning/ml_model.dart"; diff --git a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart index 8ac8091f31..e51931e254 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart @@ -5,13 +5,13 @@ import "dart:ui" show Image; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; -import "package:photos/db/embeddings_db.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/ml/embeddings_db.dart"; import "package:photos/events/diff_sync_complete_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/list.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/face.dart"; -import "package:photos/models/embedding.dart"; +import "package:photos/models/ml/clip.dart"; +import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/filedata/filedata_service.dart"; @@ -144,7 +144,7 @@ class FaceRecognitionService { } } await FaceMLDataDB.instance.bulkInsertFaces(faces); - await EmbeddingsDB.instance.putMany(clipEmbeddings); + await FaceMLDataDB.instance.putMany(clipEmbeddings); } // Yield any remaining instructions if (batchToYield.isNotEmpty) { diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 85fb95af75..f02110baca 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -7,12 +7,12 @@ import "package:logging/logging.dart"; import "package:ml_linalg/linalg.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; @@ -249,7 +249,7 @@ class ClusterFeedbackService { PeopleChangedEvent( relevantFiles: files, type: PeopleEventType.removedFilesFromCluster, - source: "$clusterID", + source: clusterID, ), ); _logger.info('removeFilesFromCluster done'); diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 5a5f4b23a6..fd676b8980 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -4,11 +4,11 @@ import "dart:developer"; import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/api/entity/type.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/services/entity_service.dart"; import "package:shared_preferences/shared_preferences.dart"; diff --git a/mobile/lib/services/machine_learning/ml_computer.dart b/mobile/lib/services/machine_learning/ml_computer.dart index 753725abc0..ed9cc719d4 100644 --- a/mobile/lib/services/machine_learning/ml_computer.dart +++ b/mobile/lib/services/machine_learning/ml_computer.dart @@ -5,7 +5,7 @@ import 'dart:typed_data' show Uint8List; import "package:dart_ui_isolate/dart_ui_isolate.dart"; import "package:logging/logging.dart"; -import "package:photos/face/model/box.dart"; +import "package:photos/models/ml/face/box.dart"; import "package:photos/services/machine_learning/ml_model.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart"; import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart"; diff --git a/mobile/lib/services/machine_learning/ml_result.dart b/mobile/lib/services/machine_learning/ml_result.dart index 254564cfc6..9385639ee8 100644 --- a/mobile/lib/services/machine_learning/ml_result.dart +++ b/mobile/lib/services/machine_learning/ml_result.dart @@ -1,6 +1,6 @@ import "dart:convert" show jsonEncode, jsonDecode; -import "package:photos/face/model/dimension.dart"; +import "package:photos/models/ml/face/dimension.dart"; import 'package:photos/models/ml/ml_typedefs.dart'; import "package:photos/models/ml/ml_versions.dart"; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 1ed24f0d46..a2ef72b3f3 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -8,13 +8,13 @@ import "package:logging/logging.dart"; import "package:package_info_plus/package_info_plus.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/machine_learning_control_event.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/box.dart"; -import "package:photos/face/model/detection.dart" as face_detection; -import "package:photos/face/model/face.dart"; -import "package:photos/face/model/landmark.dart"; +import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/detection.dart" as face_detection; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/models/ml/face/landmark.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/filedata/filedata_service.dart"; import "package:photos/services/filedata/model/file_data.dart"; diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 9329a79081..0949f4f0fd 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -9,10 +9,9 @@ import 'package:photos/data/holidays.dart'; import 'package:photos/data/months.dart'; import 'package:photos/data/years.dart'; import 'package:photos/db/files_db.dart'; +import "package:photos/db/ml/db.dart"; import 'package:photos/events/local_photos_updated_event.dart'; import "package:photos/extensions/string_ext.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/api/collection/user.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/collection/collection_items.dart'; @@ -22,6 +21,7 @@ import 'package:photos/models/file/file_type.dart'; import "package:photos/models/local_entity_data.dart"; import "package:photos/models/location/location.dart"; import "package:photos/models/location_tag/location_tag.dart"; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/search/album_search_result.dart'; import 'package:photos/models/search/generic_search_result.dart'; import "package:photos/models/search/search_constants.dart"; diff --git a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart index d295d85e1c..0f31986598 100644 --- a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart +++ b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart @@ -1,8 +1,8 @@ import 'package:flutter/material.dart'; import 'package:photos/core/constants.dart'; -import "package:photos/face/model/person.dart"; import 'package:photos/models/collection/collection.dart'; import "package:photos/models/gallery_type.dart"; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import 'package:photos/theme/ente_theme.dart'; import 'package:photos/ui/components/bottom_action_bar/action_bar_widget.dart'; diff --git a/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart b/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart index 5eed505b86..08bef662ed 100644 --- a/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/ml_debug_section_widget.dart @@ -3,9 +3,9 @@ import "dart:async"; import 'package:flutter/material.dart'; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; diff --git a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart index cf90489c6d..4508300aab 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart @@ -11,7 +11,6 @@ import 'package:photos/core/configuration.dart'; import "package:photos/core/event_bus.dart"; import "package:photos/events/guest_view_event.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/device_collection.dart'; @@ -20,6 +19,7 @@ import 'package:photos/models/file/file_type.dart'; import 'package:photos/models/files_split.dart'; import 'package:photos/models/gallery_type.dart'; import "package:photos/models/metadata/common_keys.dart"; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; import 'package:photos/services/hidden_service.dart'; diff --git a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart index bb9825f649..fbc998110a 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart @@ -1,7 +1,7 @@ import 'package:flutter/material.dart'; -import "package:photos/face/model/person.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/gallery_type.dart'; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import "package:photos/theme/effects.dart"; import "package:photos/theme/ente_theme.dart"; diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 7465f7c8c0..3ab4dac6ca 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -4,11 +4,11 @@ import "dart:typed_data"; import "package:flutter/cupertino.dart"; import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/material.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/extensions/stop_watch.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/face.dart"; -import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/models/nanoids/cluster_id.dart"; import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart index 304beba5a4..8b12547b26 100644 --- a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart +++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart @@ -3,11 +3,11 @@ import "dart:developer" as dev show log; import "package:flutter/foundation.dart" show Uint8List, kDebugMode; import "package:flutter/material.dart"; import "package:logging/logging.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/box.dart"; -import "package:photos/face/model/face.dart"; -import "package:photos/face/model/person.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/ui/components/buttons/chip_button_widget.dart"; diff --git a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart index c868e60e71..cb8dfe797a 100644 --- a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart +++ b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart @@ -7,11 +7,11 @@ import 'package:flutter/material.dart'; import "package:logging/logging.dart"; import 'package:modal_bottom_sheet/modal_bottom_sheet.dart'; import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; diff --git a/mobile/lib/ui/viewer/people/cluster_app_bar.dart b/mobile/lib/ui/viewer/people/cluster_app_bar.dart index 25b16aab34..05df78c390 100644 --- a/mobile/lib/ui/viewer/people/cluster_app_bar.dart +++ b/mobile/lib/ui/viewer/people/cluster_app_bar.dart @@ -6,12 +6,12 @@ import 'package:logging/logging.dart'; import 'package:photos/core/configuration.dart'; import 'package:photos/core/event_bus.dart'; import "package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; import 'package:photos/events/subscription_purchased_event.dart'; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/models/gallery_type.dart'; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index 05b89d8247..af41a200be 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -6,11 +6,11 @@ import 'package:photos/core/event_bus.dart'; import 'package:photos/events/files_updated_event.dart'; import 'package:photos/events/local_photos_updated_event.dart'; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/file/file.dart'; import 'package:photos/models/file_load_result.dart'; import 'package:photos/models/gallery_type.dart'; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart'; diff --git a/mobile/lib/ui/viewer/people/people_app_bar.dart b/mobile/lib/ui/viewer/people/people_app_bar.dart index 828dff6bfd..12f80fbfdf 100644 --- a/mobile/lib/ui/viewer/people/people_app_bar.dart +++ b/mobile/lib/ui/viewer/people/people_app_bar.dart @@ -7,10 +7,10 @@ import 'package:photos/core/configuration.dart'; import 'package:photos/core/event_bus.dart'; import "package:photos/events/people_changed_event.dart"; import 'package:photos/events/subscription_purchased_event.dart'; -import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/models/gallery_type.dart'; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; @@ -67,7 +67,8 @@ class _AppBarWidgetState extends State { }; collectionActions = CollectionActions(CollectionsService.instance); widget.selectedFiles.addListener(_selectedFilesListener); - _userAuthEventSubscription = Bus.instance.on().listen((event) { + _userAuthEventSubscription = + Bus.instance.on().listen((event) { setState(() {}); }); _appBarTitle = widget.title; @@ -88,7 +89,8 @@ class _AppBarWidgetState extends State { centerTitle: false, title: Text( _appBarTitle!, - style: Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16), + style: + Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16), maxLines: 2, overflow: TextOverflow.ellipsis, ), @@ -112,7 +114,8 @@ class _AppBarWidgetState extends State { } try { - await PersonService.instance.updateAttributes(widget.person.remoteID, name: text); + await PersonService.instance + .updateAttributes(widget.person.remoteID, name: text); if (mounted) { _appBarTitle = text; setState(() {}); @@ -132,7 +135,8 @@ class _AppBarWidgetState extends State { List _getDefaultActions(BuildContext context) { final List actions = []; // If the user has selected files, don't show any actions - if (widget.selectedFiles.files.isNotEmpty || !Configuration.instance.hasConfiguredAccount()) { + if (widget.selectedFiles.files.isNotEmpty || + !Configuration.instance.hasConfiguredAccount()) { return actions; } @@ -223,7 +227,8 @@ class _AppBarWidgetState extends State { unawaited( Navigator.of(context).push( MaterialPageRoute( - builder: (context) => PersonReviewClusterSuggestion(widget.person), + builder: (context) => + PersonReviewClusterSuggestion(widget.person), ), ), ); @@ -266,11 +271,13 @@ class _AppBarWidgetState extends State { bool assignName = false; await showChoiceDialog( context, - title: "Are you sure you want to show this person in people section again?", + title: + "Are you sure you want to show this person in people section again?", firstButtonLabel: "Yes, show person", firstButtonOnTap: () async { try { - await PersonService.instance.deletePerson(widget.person.remoteID, onlyMapping: false); + await PersonService.instance + .deletePerson(widget.person.remoteID, onlyMapping: false); Bus.instance.fire(PeopleChangedEvent()); assignName = true; } catch (e, s) { diff --git a/mobile/lib/ui/viewer/people/people_page.dart b/mobile/lib/ui/viewer/people/people_page.dart index f857943ea3..74a7e7f6bb 100644 --- a/mobile/lib/ui/viewer/people/people_page.dart +++ b/mobile/lib/ui/viewer/people/people_page.dart @@ -7,10 +7,10 @@ import 'package:photos/core/event_bus.dart'; import 'package:photos/events/files_updated_event.dart'; import 'package:photos/events/local_photos_updated_event.dart'; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; import 'package:photos/models/file_load_result.dart'; import 'package:photos/models/gallery_type.dart'; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/models/selected_files.dart'; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; diff --git a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart index db1141fa9e..0c5a4ccbf1 100644 --- a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart +++ b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart @@ -6,11 +6,11 @@ import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/material.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/l10n/l10n.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/theme/ente_theme.dart"; diff --git a/mobile/lib/ui/viewer/people/person_clusters_page.dart b/mobile/lib/ui/viewer/people/person_clusters_page.dart index 810fcbc779..935ddbff87 100644 --- a/mobile/lib/ui/viewer/people/person_clusters_page.dart +++ b/mobile/lib/ui/viewer/people/person_clusters_page.dart @@ -3,8 +3,8 @@ import "package:flutter/material.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/people_changed_event.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; import "package:photos/theme/ente_theme.dart"; diff --git a/mobile/lib/ui/viewer/people/person_row_item.dart b/mobile/lib/ui/viewer/people/person_row_item.dart index ed1fc9fa27..95e2453eaa 100644 --- a/mobile/lib/ui/viewer/people/person_row_item.dart +++ b/mobile/lib/ui/viewer/people/person_row_item.dart @@ -1,6 +1,6 @@ import "package:flutter/material.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/ui/viewer/search/result/person_face_widget.dart"; class PersonRowItem extends StatelessWidget { diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart index eea17134b2..1b812dd36a 100644 --- a/mobile/lib/ui/viewer/search/result/person_face_widget.dart +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -3,10 +3,10 @@ import "dart:typed_data"; import 'package:flutter/widgets.dart'; import "package:photos/db/files_db.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/face.dart"; -import "package:photos/face/model/person.dart"; +import "package:photos/db/ml/db.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/ui/common/loading_widget.dart"; import "package:photos/ui/viewer/file/thumbnail_widget.dart"; diff --git a/mobile/lib/ui/viewer/search_tab/people_section.dart b/mobile/lib/ui/viewer/search_tab/people_section.dart index b3ce01045f..3e1d7599e5 100644 --- a/mobile/lib/ui/viewer/search_tab/people_section.dart +++ b/mobile/lib/ui/viewer/search_tab/people_section.dart @@ -4,8 +4,8 @@ import "package:collection/collection.dart"; import "package:flutter/material.dart"; import "package:photos/core/constants.dart"; import "package:photos/events/event.dart"; -import "package:photos/face/model/person.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/person.dart"; import "package:photos/models/search/album_search_result.dart"; import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/recent_searches.dart"; diff --git a/mobile/lib/utils/face/face_box_crop.dart b/mobile/lib/utils/face/face_box_crop.dart index 197aa31f38..a9b5567a9f 100644 --- a/mobile/lib/utils/face/face_box_crop.dart +++ b/mobile/lib/utils/face/face_box_crop.dart @@ -2,9 +2,9 @@ import "dart:io" show File; import "package:flutter/foundation.dart"; import "package:photos/core/cache/lru_map.dart"; -import "package:photos/face/model/box.dart"; import "package:photos/models/file/file.dart"; import "package:photos/models/file/file_type.dart"; +import "package:photos/models/ml/face/box.dart"; import "package:photos/services/machine_learning/ml_computer.dart"; import "package:photos/utils/file_util.dart"; import "package:photos/utils/thumbnail_util.dart"; diff --git a/mobile/lib/utils/face/face_util.dart b/mobile/lib/utils/face/face_util.dart index 56dc8f3bf0..cb5b186b04 100644 --- a/mobile/lib/utils/face/face_util.dart +++ b/mobile/lib/utils/face/face_util.dart @@ -5,7 +5,7 @@ import "package:computer/computer.dart"; import "package:flutter_image_compress/flutter_image_compress.dart"; import "package:image/image.dart" as img; import "package:logging/logging.dart"; -import "package:photos/face/model/box.dart"; +import "package:photos/models/ml/face/box.dart"; /// Bounding box of a face. /// diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart index e566a9d536..b0badaede3 100644 --- a/mobile/lib/utils/image_ml_util.dart +++ b/mobile/lib/utils/image_ml_util.dart @@ -6,8 +6,8 @@ import "dart:ui"; import 'package:flutter/painting.dart' as paint show decodeImageFromList; import 'package:ml_linalg/linalg.dart'; -import "package:photos/face/model/box.dart"; -import "package:photos/face/model/dimension.dart"; +import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/dimension.dart"; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; import 'package:photos/services/machine_learning/face_ml/face_alignment/similarity_transform.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; diff --git a/mobile/lib/utils/ml_util.dart b/mobile/lib/utils/ml_util.dart index bc0cb6c737..5b68b9c819 100644 --- a/mobile/lib/utils/ml_util.dart +++ b/mobile/lib/utils/ml_util.dart @@ -5,13 +5,13 @@ import "dart:typed_data" show ByteData; import "package:flutter/services.dart" show PlatformException; import "package:logging/logging.dart"; import "package:photos/core/configuration.dart"; -import "package:photos/db/embeddings_db.dart"; import "package:photos/db/files_db.dart"; -import "package:photos/face/db.dart"; -import "package:photos/face/model/dimension.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/ml/embeddings_db.dart"; import "package:photos/models/file/extensions/file_props.dart"; import "package:photos/models/file/file.dart"; import "package:photos/models/file/file_type.dart"; +import "package:photos/models/ml/face/dimension.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/services/filedata/model/file_data.dart"; import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart"; @@ -54,7 +54,7 @@ Future getIndexStatus() async { final int facesIndexedFiles = await FaceMLDataDB.instance.getIndexedFileCount(); final int clipIndexedFiles = - await EmbeddingsDB.instance.getIndexedFileCount(); + await FaceMLDataDB.instance.getClipIndexedFileCount(); final int indexedFiles = math.min(facesIndexedFiles, clipIndexedFiles); final showIndexedFiles = math.min(indexedFiles, indexableFiles); @@ -73,7 +73,7 @@ Future> getFilesForMlIndexing() async { final Map faceIndexedFileIDs = await FaceMLDataDB.instance.getIndexedFileIds(); final Map clipIndexedFileIDs = - await EmbeddingsDB.instance.getIndexedFileIds(); + await FaceMLDataDB.instance.clipIndexedFileWithVersion(); // Get all regular files and all hidden files final enteFiles = await SearchService.instance.getAllFiles(); From bfec2ff2be52917f2ac9dfc2ac498db7c49f003b Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:22:02 +0530 Subject: [PATCH 13/17] [mob] Use single db for ml data --- mobile/lib/models/api/entity/type.dart | 2 ++ .../semantic_search_service.dart | 30 ++++++++++++------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/mobile/lib/models/api/entity/type.dart b/mobile/lib/models/api/entity/type.dart index 431ea4e57b..05de5c7826 100644 --- a/mobile/lib/models/api/entity/type.dart +++ b/mobile/lib/models/api/entity/type.dart @@ -15,6 +15,8 @@ EntityType typeFromString(String type) { return EntityType.location; case "person_v2": return EntityType.personV2; + case "personV2": + return EntityType.personV2; } debugPrint("unexpected collection type $type"); return EntityType.unknown; diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 1c196314b0..ec68264142 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -1,18 +1,19 @@ import "dart:async" show unawaited; import "dart:developer" as dev show log; import "dart:math" show min; -import "dart:typed_data" show ByteData; import "dart:ui" show Image; import "package:computer/computer.dart"; +import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/cache/lru_map.dart"; import "package:photos/core/event_bus.dart"; -import "package:photos/db/embeddings_db.dart"; import "package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/ml/embeddings_db.dart"; import 'package:photos/events/embedding_updated_event.dart'; -import "package:photos/models/embedding.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/clip.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/collections_service.dart"; @@ -57,7 +58,7 @@ class SemanticSearchService { return; } _hasInitialized = true; - await EmbeddingsDB.instance.init(); + await _loadImageEmbeddings(); Bus.instance.on().listen((event) { if (!_hasInitialized) return; @@ -112,7 +113,7 @@ class SemanticSearchService { } Future clearIndexes() async { - await EmbeddingsDB.instance.deleteAll(); + await FaceMLDataDB.instance.deleteClipIndexes(); final preferences = await SharedPreferences.getInstance(); await preferences.remove("sync_time_embeddings_v3"); _logger.info("Indexes cleared"); @@ -121,7 +122,7 @@ class SemanticSearchService { Future _loadImageEmbeddings() async { _logger.info("Pulling cached embeddings"); final startTime = DateTime.now(); - _cachedImageEmbeddings = await EmbeddingsDB.instance.getAll(); + _cachedImageEmbeddings = await FaceMLDataDB.instance.getAll(); final endTime = DateTime.now(); _logger.info( "Loading ${_cachedImageEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", @@ -133,7 +134,7 @@ class SemanticSearchService { Future> _getFileIDsToBeIndexed() async { final uploadedFileIDs = await getIndexableFileIDs(); - final embeddedFileIDs = await EmbeddingsDB.instance.getIndexedFileIds(); + final embeddedFileIDs = await FaceMLDataDB.instance.getIndexedFileIds(); embeddedFileIDs.removeWhere((key, value) => value < clipMlVersion); return uploadedFileIDs.difference(embeddedFileIDs.keys.toSet()).toList(); @@ -178,7 +179,7 @@ class SemanticSearchService { _logger.info(results.length.toString() + " results"); if (deletedEntries.isNotEmpty) { - unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); + unawaited(FaceMLDataDB.instance.deleteEmbeddings(deletedEntries)); } return results; @@ -221,7 +222,7 @@ class SemanticSearchService { _logger.info(results.length.toString() + " results"); if (deletedEntries.isNotEmpty) { - unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); + unawaited(FaceMLDataDB.instance.deleteEmbeddings(deletedEntries)); } final matchingFileIDs = []; @@ -253,12 +254,12 @@ class SemanticSearchService { embedding: clipResult.embedding, version: clipMlVersion, ); - await EmbeddingsDB.instance.put(embedding); + await FaceMLDataDB.instance.put(embedding); } static Future storeEmptyClipImageResult(EnteFile entefile) async { final embedding = ClipEmbedding.empty(entefile.uploadedFileID!); - await EmbeddingsDB.instance.put(embedding); + await FaceMLDataDB.instance.put(embedding); } Future> _getTextEmbedding(String query) async { @@ -320,6 +321,7 @@ List computeBulkSimilarities(Map args) { final textEmbedding = args["textEmbedding"] as List; final minimumSimilarity = args["minimumSimilarity"] ?? SemanticSearchService.kMinimumSimilarityThreshold; + double bestScore = 0.0; for (final imageEmbedding in imageEmbeddings) { final score = computeCosineSimilarity( imageEmbedding.embedding, @@ -328,6 +330,12 @@ List computeBulkSimilarities(Map args) { if (score >= minimumSimilarity) { queryResults.add(QueryResult(imageEmbedding.fileID, score)); } + if (score > bestScore) { + bestScore = score; + } + } + if (kDebugMode && queryResults.isEmpty) { + dev.log("No results found for query with best score: $bestScore"); } queryResults.sort((first, second) => second.score.compareTo(first.score)); From a90cb4e45fe07475bbc40e10b4d73df6975a63e7 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:41:11 +0530 Subject: [PATCH 14/17] [mob] Fix queries --- mobile/lib/db/ml/db.dart | 49 +++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index 11c9cf26d4..7d1a8b1700 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -29,6 +29,7 @@ class FaceMLDataDB { static final Logger _logger = Logger("FaceMLDataDB"); static const _databaseName = "ente.face_ml_db_v3.db"; + // static const _databaseVersion = 1; FaceMLDataDB._privateConstructor(); @@ -256,14 +257,20 @@ class FaceMLDataDB { final Map> result = {}; final selectQuery = ''' - SELECT fc.$fcClusterID, fe.$embeddingColumn - FROM $faceClustersTable fc - INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn - WHERE fc.$fcClusterID IN (${clusterIDs.join(',')}) - ${limit != null ? 'LIMIT $limit' : ''} - '''; + SELECT fc.$fcClusterID, fe.$embeddingColumn + FROM $faceClustersTable fc + INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn + WHERE fc.$fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')}) + ${limit != null ? 'LIMIT ?' : ''} +'''; - final List> maps = await db.getAll(selectQuery); + final List selectQueryParams = [...clusterIDs]; + if (limit != null) { + selectQueryParams.add(limit); + } + + final List> maps = + await db.getAll(selectQuery, selectQueryParams); for (final map in maps) { final clusterID = map[fcClusterID] as String; @@ -381,15 +388,14 @@ class FaceMLDataDB { ) async { final db = await instance.asyncDB; final Map> result = {}; + final List> maps = await db.getAll( - 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})', - ); - final List> maps = await db.query( - faceClustersTable, - columns: [fcClusterID, fcFaceId], - where: - '$fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')})', - whereArgs: clusterIDs, + ''' + SELECT $fcClusterID, $fcFaceId + FROM $faceClustersTable + WHERE $fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')}) + ''', + [...clusterIDs], ); for (final map in maps) { @@ -807,8 +813,12 @@ class FaceMLDataDB { final db = instance.asyncDB; return db.then((db) async { final List> maps = await db.getAll( - 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable ' - 'WHERE $fcClusterID IN (${clusterIDs.join(",")})', + ''' + SELECT $fcClusterID, $fcFaceId + FROM $faceClustersTable + WHERE $fcClusterID IN (${List.filled(clusterIDs.length, '?').join(',')}) + ''', + [...clusterIDs], ); final Map> result = {}; for (final map in maps) { @@ -876,9 +886,12 @@ class FaceMLDataDB { ) async { final db = await instance.asyncDB; final Map result = {}; + final rows = await db.getAll( - 'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})', + 'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')})', + [...clusterIDs], ); + for (final r in rows) { final id = r[clusterIDColumn] as String; final avg = r[avgColumn] as Uint8List; From b7acd3033d52aac5b725559db9ae253ed53b18d5 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:32:31 +0530 Subject: [PATCH 15/17] [mob] Fix typo --- mobile/lib/db/ml/embeddings_db.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mobile/lib/db/ml/embeddings_db.dart b/mobile/lib/db/ml/embeddings_db.dart index be537e3c67..8d59e8504b 100644 --- a/mobile/lib/db/ml/embeddings_db.dart +++ b/mobile/lib/db/ml/embeddings_db.dart @@ -23,7 +23,7 @@ extension EmbeddingsDB on FaceMLDataDB { .getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable'); final Map result = {}; for (final map in maps) { - result[map[mlVersionColumn] as int] = map[mlVersionColumn] as int; + result[map[fileIDColumn] as int] = map[mlVersionColumn] as int; } return result; } From 0c20be98d73c40207976fd7872097b3000a4df14 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:37:29 +0530 Subject: [PATCH 16/17] [mob] Allow injecting score threshold in query --- .../semantic_search/semantic_search_service.dart | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index ec68264142..b01d9d82c2 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -144,6 +144,14 @@ class SemanticSearchService { String query, { double? scoreThreshold, }) async { + // if the query starts with 0.xxx, the split the query to get score threshold and actual query + if (query.startsWith(RegExp(r"0\.\d+"))) { + final parts = query.split(" "); + if (parts.length > 1) { + scoreThreshold = double.parse(parts[0]); + query = parts.sublist(1).join(" "); + } + } final textEmbedding = await _getTextEmbedding(query); final queryResults = await _getSimilarities( From 1ca1967ec297cdf10989eb8c7b4ac959e18097c2 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Fri, 16 Aug 2024 12:13:56 +0530 Subject: [PATCH 17/17] Rename person_v2 as cgroup --- mobile/lib/models/api/entity/type.dart | 14 ++++++------- mobile/lib/services/entity_service.dart | 2 +- .../face_ml/person/person_service.dart | 20 +++++++++---------- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/mobile/lib/models/api/entity/type.dart b/mobile/lib/models/api/entity/type.dart index 05de5c7826..2114dc2c5b 100644 --- a/mobile/lib/models/api/entity/type.dart +++ b/mobile/lib/models/api/entity/type.dart @@ -3,7 +3,7 @@ import "package:flutter/foundation.dart"; enum EntityType { location, person, - personV2, + cgroup, unknown, } @@ -13,12 +13,10 @@ EntityType typeFromString(String type) { return EntityType.location; case "person": return EntityType.location; - case "person_v2": - return EntityType.personV2; - case "personV2": - return EntityType.personV2; + case "cgroup": + return EntityType.cgroup; } - debugPrint("unexpected collection type $type"); + debugPrint("unexpected entity type $type"); return EntityType.unknown; } @@ -36,8 +34,8 @@ extension EntityTypeExtn on EntityType { return "location"; case EntityType.person: return "person"; - case EntityType.personV2: - return "person_v2"; + case EntityType.cgroup: + return "cgroup"; case EntityType.unknown: return "unknown"; } diff --git a/mobile/lib/services/entity_service.dart b/mobile/lib/services/entity_service.dart index e979d4fd9e..22089e7737 100644 --- a/mobile/lib/services/entity_service.dart +++ b/mobile/lib/services/entity_service.dart @@ -101,7 +101,7 @@ class EntityService { Future syncEntities() async { try { await _remoteToLocalSync(EntityType.location); - await _remoteToLocalSync(EntityType.personV2); + await _remoteToLocalSync(EntityType.cgroup); } catch (e) { _logger.severe("Failed to sync entities", e); } diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index fd676b8980..e8a0b78fc6 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -37,7 +37,7 @@ class PersonService { } Future> getPersons() async { - final entities = await entityService.getEntities(EntityType.personV2); + final entities = await entityService.getEntities(EntityType.cgroup); return entities .map( (e) => PersonEntity(e.id, PersonData.fromJson(json.decode(e.data))), @@ -46,7 +46,7 @@ class PersonService { } Future getPerson(String id) { - return entityService.getEntity(EntityType.personV2, id).then((e) { + return entityService.getEntity(EntityType.cgroup, id).then((e) { if (e == null) { return null; } @@ -55,7 +55,7 @@ class PersonService { } Future> getPersonsMap() async { - final entities = await entityService.getEntities(EntityType.personV2); + final entities = await entityService.getEntities(EntityType.cgroup); final Map map = {}; for (var e in entities) { final person = @@ -95,7 +95,7 @@ class PersonService { ) .toList(); entityService - .addOrUpdate(EntityType.personV2, personData.toJson(), id: personID) + .addOrUpdate(EntityType.cgroup, personData.toJson(), id: personID) .ignore(); personData.logStats(); } @@ -163,7 +163,7 @@ class PersonService { isHidden: isHidden, ); final result = await entityService.addOrUpdate( - EntityType.personV2, + EntityType.cgroup, data.toJson(), ); await faceMLDataDB.assignClusterToPerson( @@ -181,7 +181,7 @@ class PersonService { final personData = person.data; personData.assigned!.removeWhere((element) => element.id != clusterID); await entityService.addOrUpdate( - EntityType.personV2, + EntityType.cgroup, personData.toJson(), id: personID, ); @@ -216,7 +216,7 @@ class PersonService { } await entityService.addOrUpdate( - EntityType.personV2, + EntityType.cgroup, personData.toJson(), id: person.remoteID, ); @@ -232,7 +232,7 @@ class PersonService { final PersonEntity justName = PersonEntity(personID, PersonData(name: entity.data.name)); await entityService.addOrUpdate( - EntityType.personV2, + EntityType.cgroup, justName.data.toJson(), id: personID, ); @@ -249,7 +249,7 @@ class PersonService { Future fetchRemoteClusterFeedback() async { await entityService.syncEntities(); - final entities = await entityService.getEntities(EntityType.personV2); + final entities = await entityService.getEntities(EntityType.cgroup); entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt)); final Map faceIdToClusterID = {}; final Map clusterToPersonID = {}; @@ -307,7 +307,7 @@ class PersonService { Future _updatePerson(PersonEntity updatePerson) async { await entityService.addOrUpdate( - EntityType.personV2, + EntityType.cgroup, updatePerson.data.toJson(), id: updatePerson.remoteID, );