diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart
index 4343a4b7b5..e9bdd74241 100644
--- a/mobile/lib/services/machine_learning/ml_service.dart
+++ b/mobile/lib/services/machine_learning/ml_service.dart
@@ -9,7 +9,6 @@ import "package:computer/computer.dart";
 import "package:dart_ui_isolate/dart_ui_isolate.dart";
 import "package:flutter/foundation.dart" show debugPrint;
 import "package:logging/logging.dart";
-import "package:onnxruntime/onnxruntime.dart";
 import "package:package_info_plus/package_info_plus.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
@@ -22,7 +21,6 @@ import "package:photos/face/model/detection.dart" as face_detection;
 import "package:photos/face/model/dimension.dart";
 import "package:photos/face/model/face.dart";
 import "package:photos/face/model/landmark.dart";
-import "package:photos/models/file/file.dart";
 import "package:photos/service_locator.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart';
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
@@ -35,6 +33,8 @@ import 'package:photos/services/machine_learning/file_ml/file_ml.dart';
 import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart';
 import 'package:photos/services/machine_learning/ml_exceptions.dart';
 import 'package:photos/services/machine_learning/ml_result.dart';
+import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
+import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/local_settings.dart";
 import "package:photos/utils/ml_util.dart";
@@ -192,7 +192,7 @@ class FaceMlService {
         break outerLoop;
       }
       await _ensureReadyForInference();
-      futures.add(processImage(instruction.enteFile));
+      futures.add(processImage(instruction));
     }
     final awaitedFutures = await Future.wait(futures);
     final sumFutures = awaitedFutures.fold(
@@ -385,22 +385,20 @@ class FaceMlService {
     }
   }
 
-  Future processImage(EnteFile enteFile) async {
+  Future processImage(FileMLInstruction instruction) async {
     // TODO: clean this function up
     _logger.info(
-      "`processImage` start processing image with uploadedFileID: ${enteFile.uploadedFileID}",
+      "`processImage` start processing image with uploadedFileID: ${instruction.enteFile.uploadedFileID}",
    );
     bool actuallyRanML = false;
     try {
       final MLResult? result = await _analyzeImageInSingleIsolate(
-        enteFile,
-        // preferUsingThumbnailForEverything: false,
-        // disposeImageIsolateAfterUse: false,
+        instruction,
       );
       if (result == null) {
         if (!_shouldPauseIndexingAndClustering) {
           _logger.severe(
-            "Failed to analyze image with uploadedFileID: ${enteFile.uploadedFileID}",
+            "Failed to analyze image with uploadedFileID: ${instruction.enteFile.uploadedFileID}",
           );
         }
         return actuallyRanML;
@@ -410,7 +408,7 @@ class FaceMlService {
       final List faces = [];
       if (result.foundNoFaces) {
         debugPrint(
-          'No faces detected for file with name:${enteFile.displayName}',
+          'No faces detected for file with name:${instruction.enteFile.displayName}',
         );
         faces.add(
           Face.empty(result.fileId, error: result.errorOccured),
@@ -421,9 +419,9 @@
             result.decodedImageSize.height == -1) {
           _logger.severe(
               "decodedImageSize is not stored correctly for image with "
-              "ID: ${enteFile.uploadedFileID}");
+              "ID: ${instruction.enteFile.uploadedFileID}");
           _logger.info(
-            "Using aligned image size for image with ID: ${enteFile.uploadedFileID}. This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata",
+            "Using aligned image size for image with ID: ${instruction.enteFile.uploadedFileID}. This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${instruction.enteFile.width}x${instruction.enteFile.height} in the metadata",
           );
         }
         for (int i = 0; i < result.faces!.length; ++i) {
@@ -463,9 +461,9 @@ class FaceMlService {
         _logger.info("inserting ${faces.length} faces for ${result.fileId}");
         if (!result.errorOccured) {
           await RemoteFileMLService.instance.putFileEmbedding(
-            enteFile,
+            instruction.enteFile,
             FileMl(
-              enteFile.uploadedFileID!,
+              instruction.enteFile.uploadedFileID!,
               FaceEmbeddings(
                 faces,
                 result.mlVersion,
@@ -483,18 +481,30 @@ class FaceMlService {
         await FaceMLDataDB.instance.bulkInsertFaces(faces);
         return actuallyRanML;
       }
+
+      if (result.clipRan) {
+        actuallyRanML = true;
+        await SemanticSearchService.storeClipImageResult(
+          result.clip!,
+          instruction.enteFile,
+        );
+      }
     } on ThumbnailRetrievalException catch (e, s) {
       _logger.severe(
-        'ThumbnailRetrievalException while processing image with ID ${enteFile.uploadedFileID}, storing empty face so indexing does not get stuck',
+        'ThumbnailRetrievalException while processing image with ID ${instruction.enteFile.uploadedFileID}, storing empty face so indexing does not get stuck',
         e,
         s,
       );
-      await FaceMLDataDB.instance
-          .bulkInsertFaces([Face.empty(enteFile.uploadedFileID!, error: true)]);
+      await FaceMLDataDB.instance.bulkInsertFaces(
+        [Face.empty(instruction.enteFile.uploadedFileID!, error: true)],
+      );
+      await SemanticSearchService.storeEmptyClipImageResult(
+        instruction.enteFile,
+      );
       return true;
     } catch (e, s) {
       _logger.severe(
-        "Failed to analyze using FaceML for image with ID: ${enteFile.uploadedFileID}. Not storing any faces, which means it will be automatically retried later.",
+        "Failed to analyze using FaceML for image with ID: ${instruction.enteFile.uploadedFileID}. Not storing any faces, which means it will be automatically retried later.",
         e,
         s,
       );
@@ -524,6 +534,11 @@ class FaceMlService {
     } catch (e, s) {
       _logger.severe("Could not initialize mobilefacenet", e, s);
     }
+    try {
+      await ClipImageEncoder.instance.init();
+    } catch (e, s) {
+      _logger.severe("Could not initialize clip image", e, s);
+    }
     _isModelsInitialized = true;
     _logger.info('initModels done');
     _logStatus();
@@ -546,7 +561,11 @@ class FaceMlService {
       } catch (e, s) {
         _logger.severe("Could not dispose mobilefacenet", e, s);
       }
-      OrtEnv.instance.release();
+      try {
+        await ClipImageEncoder.instance.release();
+      } catch (e, s) {
+        _logger.severe("Could not dispose clip image", e, s);
+      }
       _isModelsInitialized = false;
     });
   }
@@ -681,9 +700,13 @@ class FaceMlService {
   }
 
   /// Analyzes the given image data by running the full pipeline for faces, using [_analyzeImageSync] in the isolate.
-  Future _analyzeImageInSingleIsolate(EnteFile enteFile) async {
-    final String filePath =
-        await getImagePathForML(enteFile, typeOfData: FileDataForML.fileData);
+  Future _analyzeImageInSingleIsolate(
+    FileMLInstruction instruction,
+  ) async {
+    final String filePath = await getImagePathForML(
+      instruction.enteFile,
+      typeOfData: FileDataForML.fileData,
+    );
 
     final Stopwatch stopwatch = Stopwatch()..start();
     late MLResult result;
@@ -693,12 +716,15 @@ class FaceMlService {
         (
           FaceMlOperation.analyzeImage,
           {
-            "enteFileID": enteFile.uploadedFileID ?? -1,
+            "enteFileID": instruction.enteFile.uploadedFileID ?? -1,
             "filePath": filePath,
+            "runFaces": instruction.shouldRunFaces,
+            "runClip": instruction.shouldRunClip,
             "faceDetectionAddress":
                 FaceDetectionService.instance.sessionAddress,
             "faceEmbeddingAddress":
                 FaceEmbeddingService.instance.sessionAddress,
+            "clipImageAddress": ClipImageEncoder.instance.sessionAddress,
           }
         ),
       ) as String?;
@@ -711,20 +737,21 @@ class FaceMlService {
       result = MLResult.fromJsonString(resultJsonString);
     } catch (e, s) {
       _logger.severe(
-        "Could not analyze image with ID ${enteFile.uploadedFileID} \n",
+        "Could not analyze image with ID ${instruction.enteFile.uploadedFileID} \n",
         e,
         s,
       );
       debugPrint(
-        "This image with ID ${enteFile.uploadedFileID} has name ${enteFile.displayName}.",
+        "This image with ID ${instruction.enteFile.uploadedFileID} has name ${instruction.enteFile.displayName}.",
       );
-      final resultBuilder = MLResult.fromEnteFileID(enteFile.uploadedFileID!)
-        ..errorOccurred();
+      final resultBuilder =
+          MLResult.fromEnteFileID(instruction.enteFile.uploadedFileID!)
+            ..errorOccurred();
       return resultBuilder;
     }
     stopwatch.stop();
     _logger.info(
-      "Finished Analyze image with uploadedFileID ${enteFile.uploadedFileID}, in "
+      "Finished Analyze image with uploadedFileID ${instruction.enteFile.uploadedFileID}, in "
      "${stopwatch.elapsedMilliseconds} ms (including time waiting for inference engine availability)",
     );
 
@@ -735,8 +762,11 @@ class FaceMlService {
     try {
       final int enteFileID = args["enteFileID"] as int;
       final String imagePath = args["filePath"] as String;
+      final bool runFaces = args["runFaces"] as bool;
+      final bool runClip = args["runClip"] as bool;
       final int faceDetectionAddress = args["faceDetectionAddress"] as int;
       final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int;
+      final int clipImageAddress = args["clipImageAddress"] as int;
 
       dev.log(
         "Start analyzing image with uploadedFileID: $enteFileID inside the isolate",
@@ -754,18 +784,29 @@ class FaceMlService {
       final result = MLResult.fromEnteFileID(enteFileID);
       result.decodedImageSize = decodedImageSize;
 
-      final resultFaces = await FaceRecognitionService.runFacesPipeline(
-        enteFileID,
-        image,
-        imageByteData,
-        faceDetectionAddress,
-        faceEmbeddingAddress,
-      );
-      if (resultFaces.isEmpty) {
-        return result..noFaceDetected();
+      if (runFaces) {
+        final resultFaces = await FaceRecognitionService.runFacesPipeline(
+          enteFileID,
+          image,
+          imageByteData,
+          faceDetectionAddress,
+          faceEmbeddingAddress,
+        );
+        if (resultFaces.isEmpty) {
+          return result..noFaceDetected();
+        }
+        result.faces = resultFaces;
       }
-      result.faces = resultFaces;
+
+      if (runClip) {
+        final clipResult = await SemanticSearchService.runClipImage(
+          enteFileID,
+          image,
+          imageByteData,
+          clipImageAddress,
+        );
+        result.clip = clipResult;
+      }
 
       return result;
     } catch (e, s) {
diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
index b563186b00..3bacee5fd5 100644
--- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
@@ -1,5 +1,5 @@
-import "dart:io";
-import "dart:typed_data";
+import "dart:typed_data" show ByteData;
+import "dart:ui" show Image;
 
 import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
@@ -25,17 +25,17 @@ class ClipImageEncoder extends MlModel {
   static final instance = ClipImageEncoder._privateConstructor();
   factory ClipImageEncoder() => instance;
 
-  static Future<List<double>> predict(Map args) async {
-    final imageData = await File(args["imagePath"]).readAsBytes();
-    final image = await decodeImageFromData(imageData);
-    final ByteData imgByteData = await getByteDataFromImage(image);
-
-    final inputList = await preprocessImageClip(image, imgByteData);
+  static Future<List<double>> predict(
+    Image image,
+    ByteData imageByteData,
+    int sessionAddress,
+  ) async {
+    final inputList = await preprocessImageClip(image, imageByteData);
 
     final inputOrt =
         OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
     final inputs = {'input': inputOrt};
 
-    final session = OrtSession.fromAddress(args["address"]);
+    final session = OrtSession.fromAddress(sessionAddress);
     final runOptions = OrtRunOptions();
     final outputs = session.run(runOptions, inputs);
     final embedding = (outputs[0]?.value as List<List<double>>)[0];
diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
index 0c717d9791..4cfd7b9d20 100644
--- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart
@@ -2,9 +2,11 @@ import "dart:async";
 import "dart:collection";
 import "dart:developer" as dev show log;
 import "dart:math" show min;
+import "dart:typed_data" show ByteData;
+import "dart:ui" show Image;
 
 import "package:computer/computer.dart";
-import "package:flutter/services.dart";
+import "package:flutter/services.dart" show PlatformException;
 import "package:logging/logging.dart";
 import "package:photos/core/cache/lru_map.dart";
 import "package:photos/core/configuration.dart";
@@ -17,14 +19,13 @@ import "package:photos/events/file_uploaded_event.dart";
 import "package:photos/events/machine_learning_control_event.dart";
 import "package:photos/models/embedding.dart";
 import "package:photos/models/file/file.dart";
+import "package:photos/models/ml/ml_versions.dart";
 import "package:photos/services/collections_service.dart";
 import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
+import "package:photos/services/machine_learning/ml_result.dart";
+import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
 import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
-import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
-import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
-import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
 import "package:photos/utils/debouncer.dart";
-import "package:photos/utils/device_info.dart";
 import "package:photos/utils/local_settings.dart";
 import "package:photos/utils/ml_util.dart";
 // import "package:photos/utils/thumbnail_util.dart";
@@ -44,12 +45,9 @@ class SemanticSearchService {
 
   final _logger = Logger("SemanticSearchService");
   final _queue = Queue();
-  final _frameworkInitialization = Completer();
   final _embeddingLoaderDebouncer =
       Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
 
-  late Model _currentModel;
-  late MLFramework _mlFramework;
   bool _hasInitialized = false;
   bool _isComputingEmbeddings = false;
   bool _isSyncing = false;
@@ -70,12 +68,6 @@ class SemanticSearchService {
       return;
     }
     _hasInitialized = true;
-    final shouldDownloadOverMobileData =
-        Configuration.instance.shouldBackupOverMobileData();
-    _currentModel = await _getCurrentModel();
-    _mlFramework = _currentModel == Model.onnxClip
-        ? ONNX(shouldDownloadOverMobileData)
-        : GGML(shouldDownloadOverMobileData);
     await EmbeddingStore.instance.init();
     await EmbeddingsDB.instance.init();
     await _loadEmbeddings();
@@ -115,19 +107,12 @@ class SemanticSearchService {
     });
   }
 
-  Future release() async {
-    if (_frameworkInitialization.isCompleted) {
-      await _mlFramework.release();
-    }
-  }
-
   Future sync() async {
     if (_isSyncing) {
       return;
     }
     _isSyncing = true;
-    final fetchCompleted =
-        await EmbeddingStore.instance.pullEmbeddings(_currentModel);
+    final fetchCompleted = await EmbeddingStore.instance.pullEmbeddings();
     if (fetchCompleted) {
       await _backFill();
     }
@@ -183,14 +168,14 @@ class SemanticSearchService {
   }
 
   Future clearIndexes() async {
-    await EmbeddingStore.instance.clearEmbeddings(_currentModel);
-    _logger.info("Indexes cleared for $_currentModel");
+    await EmbeddingStore.instance.clearEmbeddings();
+    _logger.info("Indexes cleared");
   }
 
   Future _loadEmbeddings() async {
     _logger.info("Pulling cached embeddings");
     final startTime = DateTime.now();
-    _cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
+    _cachedEmbeddings = await EmbeddingsDB.instance.getAll();
     final endTime = DateTime.now();
     _logger.info(
       "Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
@@ -219,10 +204,10 @@
 
   Future<List<int>> _getFileIDsToBeIndexed() async {
     final uploadedFileIDs = await getIndexableFileIDs();
-    final embeddedFileIDs =
-        await EmbeddingsDB.instance.getFileIDs(_currentModel);
+    final embeddedFileIDs = await EmbeddingsDB.instance.getIndexedFileIds();
+    embeddedFileIDs.removeWhere((key, value) => value < clipMlVersion);
 
-    return uploadedFileIDs.difference(embeddedFileIDs).toList();
+    return uploadedFileIDs.difference(embeddedFileIDs.keys.toSet()).toList();
   }
 
   Future clearQueue() async {
@@ -386,7 +371,6 @@ class SemanticSearchService {
 
       final embedding = Embedding(
         fileID: file.uploadedFileID!,
-        model: _currentModel,
         embedding: result,
       );
       await EmbeddingStore.instance.storeEmbedding(
@@ -398,7 +382,7 @@ class SemanticSearchService {
         "Could not get embedding for $file because FormatException occured, storing empty result locally",
         e,
       );
-      final embedding = Embedding.empty(file.uploadedFileID!, _currentModel);
+      final embedding = Embedding.empty(file.uploadedFileID!);
       await EmbeddingsDB.instance.put(embedding);
     } on PlatformException catch (e, s) {
       _logger.severe(
@@ -406,13 +390,32 @@ class SemanticSearchService {
         e,
         s,
       );
-      final embedding = Embedding.empty(file.uploadedFileID!, _currentModel);
+      final embedding = Embedding.empty(file.uploadedFileID!);
       await EmbeddingsDB.instance.put(embedding);
     } catch (e, s) {
       _logger.severe(e, s);
     }
   }
 
+  static Future storeClipImageResult(
+    ClipResult clipResult,
+    EnteFile entefile,
+  ) async {
+    final embedding = Embedding(
+      fileID: clipResult.fileID,
+      embedding: clipResult.embedding,
+    );
+    await EmbeddingStore.instance.storeEmbedding(
+      entefile,
+      embedding,
+    );
+  }
+
+  static Future storeEmptyClipImageResult(EnteFile entefile) async {
+    final embedding = Embedding.empty(entefile.uploadedFileID!);
+    await EmbeddingsDB.instance.put(embedding);
+  }
+
   Future<List<double>> _getTextEmbedding(String query) async {
     _logger.info("Searching for " + query);
     final cachedResult = _queryCache.get(query);
@@ -453,14 +456,6 @@ class SemanticSearchService {
     return queryResults;
   }
 
-  Future _getCurrentModel() async {
-    if (await isGrapheneOS()) {
-      return Model.ggmlClip;
-    } else {
-      return Model.onnxClip;
-    }
-  }
-
   void _startIndexing() {
     _logger.info("Start indexing");
     if (!_mlController.isCompleted) {
@@ -474,6 +469,19 @@ class SemanticSearchService {
       _mlController = Completer();
     }
   }
+
+  static Future runClipImage(
+    int enteFileID,
+    Image image,
+    ByteData imageByteData,
+    int clipImageAddress,
+  ) async {
+    final embedding =
+        await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
+    final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);
+
+    return clipResult;
+  }
 }
 
 List computeBulkSimilarities(Map args) {