From e9a24efecb1e54962adb54eb9b1926a05848754e Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 14 Apr 2025 16:23:06 +0530 Subject: [PATCH] Make ml related classes generic --- mobile/lib/db/ml/base.dart | 2 +- mobile/lib/models/ml/clip.dart | 6 ++-- mobile/lib/models/ml/face/face.dart | 12 ++++---- .../face_ml/face_detection/detection.dart | 3 +- .../face_ml/face_recognition_service.dart | 8 +++--- .../machine_learning/ml_indexing_isolate.dart | 6 ++-- .../services/machine_learning/ml_result.dart | 28 ++++++++++--------- .../services/machine_learning/ml_service.dart | 23 +++++++++------ .../semantic_search_service.dart | 19 ++----------- mobile/lib/utils/ml_util.dart | 10 +++---- 10 files changed, 55 insertions(+), 62 deletions(-) diff --git a/mobile/lib/db/ml/base.dart b/mobile/lib/db/ml/base.dart index 5e2879d463..35e745fd1f 100644 --- a/mobile/lib/db/ml/base.dart +++ b/mobile/lib/db/ml/base.dart @@ -112,7 +112,7 @@ abstract class IMLDataDB { Future> getAllClipVectors(); Future> clipIndexedFileWithVersion(); Future getClipIndexedFileCount({int minimumMlVersion}); - Future putClip(List embeddings); + Future putClip(List> embeddings); Future deleteClipEmbeddings(List fileIDs); Future deleteClipIndexes(); } diff --git a/mobile/lib/models/ml/clip.dart b/mobile/lib/models/ml/clip.dart index 570dae388a..ce6eb2707f 100644 --- a/mobile/lib/models/ml/clip.dart +++ b/mobile/lib/models/ml/clip.dart @@ -2,8 +2,8 @@ import "dart:convert"; import "package:photos/models/ml/ml_versions.dart"; -class ClipEmbedding { - final int fileID; +class ClipEmbedding { + final T fileID; final List embedding; int version; @@ -15,7 +15,7 @@ class ClipEmbedding { required this.version, }); - factory ClipEmbedding.empty(int fileID) { + factory ClipEmbedding.empty(T fileID) { return ClipEmbedding( fileID: fileID, embedding: [], diff --git a/mobile/lib/models/ml/face/face.dart b/mobile/lib/models/ml/face/face.dart index 5e4c554e3b..e1c04821b1 100644 --- a/mobile/lib/models/ml/face/face.dart +++ b/mobile/lib/models/ml/face/face.dart @@ -16,7 +16,7 @@ class FileInfo { }); } -class Face { +class Face { final String faceID; final List embedding; Detection detection; @@ -26,7 +26,7 @@ class Face { ///#region Local DB fields // This is not stored on the server, using it for local DB row FileInfo? fileInfo; - final int fileID; + final T fileID; ///#endregion @@ -48,7 +48,7 @@ class Face { factory Face.fromFaceResult( FaceResult faceResult, - int fileID, + T fileID, Dimensions decodedDimensions, ) { final detection = Detection( @@ -81,7 +81,7 @@ class Face { ); } - factory Face.empty(int fileID, {bool error = false}) { + static Face empty(T fileID, {bool error = false}) { return Face( "${fileID}_0_0_0_0", fileID, @@ -92,9 +92,9 @@ class Face { ); } - factory Face.fromJson(Map json) { + static Face fromJson(Map json) { final String faceID = json['faceID'] as String; - final int fileID = getFileIdFromFaceId(faceID); + final T fileID = getFileIdFromFaceId(faceID); return Face( faceID, fileID, diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index f7a7e630e2..bb7a97568a 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -162,7 +162,8 @@ class FaceDetectionRelative extends Detection { ); } - String toFaceID({required int fileID}) { + // todo: neeraj review later #localgallery + String toFaceID({required dynamic fileID}) { // Assert that the values are within the expected range assert( (xMinBox >= 0 && xMinBox <= 1) && diff --git a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart index e653751f4f..ab1457b8f8 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart @@ -71,8 +71,8 @@ class FaceRecognitionService { } } - static Future> runFacesPipeline( - int enteFileID, + static Future> runFacesPipeline( + T enteFileID, Image image, Uint8List rawRgbaBytes, int faceDetectionAddress, @@ -131,8 +131,8 @@ class FaceRecognitionService { } /// Runs face recognition on the given image data. - static Future> _detectFacesSync( - int fileID, + static Future> _detectFacesSync( + T fileID, Image image, Uint8List rawRgbaBytes, int interpreterAddress, diff --git a/mobile/lib/services/machine_learning/ml_indexing_isolate.dart b/mobile/lib/services/machine_learning/ml_indexing_isolate.dart index 52da789eb7..2b115a4d45 100644 --- a/mobile/lib/services/machine_learning/ml_indexing_isolate.dart +++ b/mobile/lib/services/machine_learning/ml_indexing_isolate.dart @@ -55,11 +55,11 @@ class MLIndexingIsolate extends SuperIsolate { factory MLIndexingIsolate() => instance; /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageStatic] in the isolate. - Future analyzeImage( + Future?> analyzeImage( FileMLInstruction instruction, String filePath, ) async { - late MLResult result; + late MLResult result; try { final resultJsonString = @@ -78,7 +78,7 @@ class MLIndexingIsolate extends SuperIsolate { } return null; } - result = MLResult.fromJsonString(resultJsonString); + result = MLResult.fromJsonString(resultJsonString); } catch (e, s) { _logger.severe( "Could not analyze image with ID ${instruction.file.uploadedFileID} \n", diff --git a/mobile/lib/services/machine_learning/ml_result.dart b/mobile/lib/services/machine_learning/ml_result.dart index e38b20097a..d541c8c189 100644 --- a/mobile/lib/services/machine_learning/ml_result.dart +++ b/mobile/lib/services/machine_learning/ml_result.dart @@ -7,8 +7,8 @@ import 'package:photos/services/machine_learning/face_ml/face_alignment/alignmen import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; -class MLResult { - int fileId; +class MLResult { + T fileId; List? faces = []; ClipResult? clip; @@ -20,14 +20,14 @@ class MLResult { bool get clipRan => clip != null; MLResult({ - this.fileId = -1, + required this.fileId, this.faces, this.clip, this.decodedImageSize = const Dimensions(width: -1, height: -1), }); MLResult.fromEnteFileID( - fileID, { + T fileID, { this.decodedImageSize = const Dimensions(width: -1, height: -1), }) : fileId = fileID; @@ -43,7 +43,9 @@ class MLResult { String toJsonString() => jsonEncode(_toJson()); - static MLResult _fromJson(Map json) { + static MLResult _fromJson( + Map json, + ) { return MLResult( fileId: json['fileId'], faces: json['faces'] != null @@ -70,13 +72,13 @@ class MLResult { ); } - static MLResult fromJsonString(String jsonString) { - return _fromJson(jsonDecode(jsonString)); + static MLResult fromJsonString(String jsonString) { + return _fromJson(jsonDecode(jsonString)); } } -class ClipResult { - final int fileID; +class ClipResult { + final T fileID; final Embedding embedding; ClipResult({ @@ -97,12 +99,12 @@ class ClipResult { } } -class FaceResult { +class FaceResult { late FaceDetectionRelative detection; late double blurValue; late AlignmentResult alignment; late Embedding embedding; - late int fileId; + late T fileId; late String faceId; bool get isBlurry => blurValue < kLaplacianHardThreshold; @@ -118,7 +120,7 @@ class FaceResult { FaceResult.fromFaceDetection( FaceDetectionRelative faceDetection, - int fileID, + T fileID, ) { fileId = fileID; faceId = faceDetection.toFaceID(fileID: fileID); @@ -146,7 +148,7 @@ class FaceResult { } } -T getFileIdFromFaceId(String faceId) { +T getFileIdFromFaceId(String faceId) { final String faceIdSplit = faceId.substring(0, faceId.indexOf('_')); if (T == int) { return int.parse(faceIdSplit) as T; diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 0961b50e4c..49d50610a8 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -10,6 +10,7 @@ import "package:photos/db/files_db.dart"; import "package:photos/db/ml/db.dart"; import "package:photos/events/machine_learning_control_event.dart"; import "package:photos/events/people_changed_event.dart"; +import "package:photos/models/ml/clip.dart"; import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/ml_versions.dart"; import "package:photos/service_locator.dart"; @@ -20,7 +21,6 @@ import "package:photos/services/machine_learning/face_ml/face_clustering/face_db import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/ml_indexing_isolate.dart"; import 'package:photos/services/machine_learning/ml_result.dart'; -import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; import "package:photos/utils/ml_util.dart"; import "package:photos/utils/network_util.dart"; import "package:photos/utils/ram_check_util.dart"; @@ -428,13 +428,14 @@ class MLService { } } - Future processImage(FileMLInstruction instruction) async { + Future processImage(FileMLInstruction instruction) async { bool actuallyRanML = false; try { final String filePath = await getImagePathForML(instruction.file); - final MLResult? result = await MLIndexingIsolate.instance.analyzeImage( + final MLResult? result = + await MLIndexingIsolate.instance.analyzeImage( instruction, filePath, ); @@ -502,9 +503,13 @@ class MLService { // Storing results locally if (result.facesRan) await mlDataDB.bulkInsertFaces(faces); if (result.clipRan) { - await SemanticSearchService.instance.storeClipImageResult( - result.clip!, - ); + await mlDataDB.putClip([ + ClipEmbedding( + fileID: result.clip!.fileID, + embedding: result.clip!.embedding, + version: clipMlVersion, + ), + ]); } _logger.info("ML results for fileID ${result.fileId} stored locally"); return actuallyRanML; @@ -526,9 +531,9 @@ class MLService { await mlDataDB.bulkInsertFaces( [Face.empty(instruction.file.uploadedFileID!, error: true)], ); - await SemanticSearchService.instance.storeEmptyClipImageResult( - instruction.file, - ); + + await mlDataDB + .putClip([ClipEmbedding.empty(instruction.file.uploadedFileID!)]); return true; } _logger.severe( diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index cf3930a782..11570bec62 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -12,7 +12,6 @@ import "package:photos/db/ml/db.dart"; import 'package:photos/events/embedding_updated_event.dart'; import "package:photos/models/file/file.dart"; import "package:photos/models/ml/clip.dart"; -import "package:photos/models/ml/ml_versions.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/collections_service.dart"; import "package:photos/services/machine_learning/ml_computer.dart"; @@ -235,20 +234,6 @@ class SemanticSearchService { _logger.info("Clip text model loaded"); } - Future storeClipImageResult(ClipResult clipResult) async { - final embedding = ClipEmbedding( - fileID: clipResult.fileID, - embedding: clipResult.embedding, - version: clipMlVersion, - ); - await mlDataDB.putClip([embedding]); - } - - Future storeEmptyClipImageResult(EnteFile entefile) async { - final embedding = ClipEmbedding.empty(entefile.uploadedFileID!); - await mlDataDB.putClip([embedding]); - } - Future> _getTextEmbedding(String query) async { _logger.info("Searching for ${kDebugMode ? query : ''}"); final cachedResult = _queryEmbeddingCache.get(query); @@ -292,8 +277,8 @@ class SemanticSearchService { }); } - static Future runClipImage( - int enteFileID, + static Future> runClipImage( + T enteFileID, Image image, Uint8List rawRgbaBytes, int clipImageAddress, diff --git a/mobile/lib/utils/ml_util.dart b/mobile/lib/utils/ml_util.dart index d17f2f7ecc..6cd7556aaa 100644 --- a/mobile/lib/utils/ml_util.dart +++ b/mobile/lib/utils/ml_util.dart @@ -221,7 +221,7 @@ Stream> fetchEmbeddingsAndInstructions( fileMl.getClipEmbeddingIfCompatible(clipMlVersion); if (remoteClipEmbedding != null) { clipEmbeddings.add( - ClipEmbedding( + ClipEmbedding( fileID: fileMl.fileID, embedding: remoteClipEmbedding.embedding, version: remoteClipEmbedding.version, @@ -395,9 +395,9 @@ void normalizeEmbedding(List embedding) { } } -Future analyzeImageStatic(Map args) async { +Future> analyzeImageStatic(Map args) async { try { - final int enteFileID = args["enteFileID"] as int; + final T enteFileID = args["enteFileID"] as T; final String imagePath = args["filePath"] as String; final bool runFaces = args["runFaces"] as bool; final bool runClip = args["runClip"] as bool; @@ -422,7 +422,7 @@ Future analyzeImageStatic(Map args) async { String faceMsString = "", clipMsString = ""; final pipelines = await Future.wait([ runFaces - ? FaceRecognitionService.runFacesPipeline( + ? FaceRecognitionService.runFacesPipeline( enteFileID, image, rawRgbaBytes, @@ -435,7 +435,7 @@ Future analyzeImageStatic(Map args) async { }) : Future.value(null), runClip - ? SemanticSearchService.runClipImage( + ? SemanticSearchService.runClipImage( enteFileID, image, rawRgbaBytes,