diff --git a/mobile/lib/services/machine_learning/ml_model.dart b/mobile/lib/services/machine_learning/ml_model.dart index eef288423e..384ca2b4c5 100644 --- a/mobile/lib/services/machine_learning/ml_model.dart +++ b/mobile/lib/services/machine_learning/ml_model.dart @@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/machine_learning/onnx_env.dart"; import "package:photos/services/remote_assets_service.dart"; +import "package:synchronized/synchronized.dart"; abstract class MlModel { static final Logger isolateLogger = Logger("MlModelInIsolate"); @@ -16,6 +17,8 @@ abstract class MlModel { String get modelName; + final _downloadModelLock = Lock(); + static final bool usePlatformPlugin = Platform.isAndroid; bool get isInitialized => @@ -32,9 +35,21 @@ abstract class MlModel { int _nativePluginSessionIndex = -1; Future<(String, String)> getModelNameAndPath() async { - final path = - await RemoteAssetsService.instance.getAssetPath(modelRemotePath); - return (modelName, path); + return _downloadModelLock.synchronized(() async { + final path = + await RemoteAssetsService.instance.getAssetPath(modelRemotePath); + return (modelName, path); + }); + } + + Future downloadModel([bool forceRefresh = false]) async { + return _downloadModelLock.synchronized(() async { + if (forceRefresh) { + await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath); + } else { + await RemoteAssetsService.instance.getAsset(modelRemotePath); + } + }); } void storeSessionAddress(int address) { diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 31887a7bf0..4235d3cb36 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -43,8 +43,10 @@ class MLService { factory MLService() => instance; final _initModelLock = Lock(); + final _downloadModelLock = Lock(); bool _isInitialized = false; + bool areModelsDownloaded = false; late String client; @@ -52,8 +54,6 @@ class MLService { bool get showClusteringIsHappening => _showClusteringIsHappening; - bool modelsAreLoading = false; - bool debugIndexingDisabled = false; bool _showClusteringIsHappening = false; bool _mlControllerStatus = false; @@ -64,7 +64,7 @@ class MLService { static const _kForceClusteringFaceCount = 8000; /// Only call this function once at app startup, after that you can directly call [runAllML] - Future init() async { + Future init({bool firstTime = false}) async { if (localSettings.isFaceIndexingEnabled == false || _isInitialized) { return; } @@ -78,6 +78,9 @@ class MLService { // Activate FaceRecognitionService await FaceRecognitionService.instance.init(); + // Download models if not already downloaded + unawaited(_ensureDownloadedModels(firstTime)); + // Listen on MachineLearningController Bus.instance.on().listen((event) { if (localSettings.isFaceIndexingEnabled == false) { @@ -509,6 +512,25 @@ class MLService { return actuallyRanML; } + Future _ensureDownloadedModels([bool forceRefresh = false]) async { + if (_downloadModelLock.locked) { + _logger.finest("Download models already in progress"); + } + return _downloadModelLock.synchronized(() async { + if (areModelsDownloaded) { + _logger.finest("Models already downloaded"); + return; + } + _logger.info('Downloading models'); + await Future.wait([ + FaceDetectionService.instance.downloadModel(forceRefresh), + FaceEmbeddingService.instance.downloadModel(forceRefresh), + ClipImageEncoder.instance.downloadModel(forceRefresh), + ]); + areModelsDownloaded = true; + }); + } + Future _ensureLoadedModels(FileMLInstruction instruction) async { return _initModelLock.synchronized(() async { final faceDetectionLoaded = FaceDetectionService.instance.isInitialized; @@ -522,7 +544,6 @@ class MLService { return; } - modelsAreLoading = true; _logger.info( 'Loading models. faces: $shouldLoadFaces, clip: $shouldLoadClip', ); @@ -530,7 +551,6 @@ class MLService { .loadModels(loadFaces: shouldLoadFaces, loadClip: shouldLoadClip); _logger.info('Models loaded'); _logStatus(); - modelsAreLoading = false; }); } diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index b7f9158cc5..409f353469 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -119,8 +119,7 @@ class _MachineLearningSettingsPageState onChanged: () async { final isEnabled = await localSettings.toggleFaceIndexing(); if (isEnabled) { - await MLService.instance.init(); - MLService.instance.downloadModels().ignore(); + await MLService.instance.init(firstTime: true); await SemanticSearchService.instance.init(); unawaited(MLService.instance.runAllML(force: true)); @@ -138,9 +137,9 @@ class _MachineLearningSettingsPageState height: 12, ), hasEnabled - ? MLService.instance.modelsAreLoading - ? const ModelLoadingState() - : const MLStatusWidget() + ? MLService.instance.areModelsDownloaded + ? const MLStatusWidget() + : const ModelLoadingState() : const SizedBox.shrink(), ], );