From fcd1ec3b05da37d742ab1b5f4bd1ada845de225c Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Wed, 24 Jul 2024 17:28:45 +0530 Subject: [PATCH] [mob] Download models before interence --- .../services/machine_learning/ml_model.dart | 4 +++ .../services/machine_learning/ml_service.dart | 31 ++++++++++++++++++- .../machine_learning_settings_page.dart | 4 ++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/mobile/lib/services/machine_learning/ml_model.dart b/mobile/lib/services/machine_learning/ml_model.dart index 1127f57d27..261a741c5e 100644 --- a/mobile/lib/services/machine_learning/ml_model.dart +++ b/mobile/lib/services/machine_learning/ml_model.dart @@ -38,6 +38,10 @@ abstract class MlModel { } } + Future downloadModel() async { + await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath); + } + Future _loadModelWithEntePlugin( String modelName, String modelPath, diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index c95ba7920b..257d9c074d 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -33,7 +33,9 @@ import "package:photos/services/machine_learning/file_ml/remote_fileml_service.d import 'package:photos/services/machine_learning/ml_exceptions.dart'; import 'package:photos/services/machine_learning/ml_result.dart'; import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart"; +import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart"; import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; +import "package:photos/services/remote_assets_service.dart"; import "package:photos/utils/image_ml_util.dart"; import "package:photos/utils/ml_util.dart"; import "package:photos/utils/network_util.dart"; @@ -65,11 +67,13 @@ class MLService { factory MLService() => instance; final _initModelLock = Lock(); + final _downloadModelLock = Lock(); final _functionLock = Lock(); final _initIsolateLock = Lock(); bool _isInitialized = false; bool _isModelsInitialized = false; + bool areModelDownloaded = false; bool _isModelsInitUsingEntePlugin = false; bool _isIsolateSpawned = false; @@ -525,6 +529,31 @@ class MLService { return actuallyRanML; } + Future downloadModels() { + if (areModelDownloaded) { + _logger.finest("Models already downloaded"); + return Future.value(); + } + if (_downloadModelLock.locked) { + _logger.finest("Download models already in progress"); + } + return _downloadModelLock.synchronized(() async { + if (areModelDownloaded) { + return; + } + _logger.info('Downloading models'); + await Future.wait([ + FaceDetectionService.instance.downloadModel(), + FaceEmbeddingService.instance.downloadModel(), + ClipImageEncoder.instance.downloadModel(), + ClipImageEncoder.instance.downloadModel(), + RemoteAssetsService.instance + .getAsset(ClipTextTokenizer.kVocabRemotePath), + ]); + areModelDownloaded = true; + }); + } + Future _initModelsUsingFfiBasedPlugin() async { return _initModelLock.synchronized(() async { if (_isModelsInitialized) return; @@ -632,7 +661,7 @@ class MLService { Future _ensureReadyForInference() async { await _initIsolate(); - await _initModelsUsingFfiBasedPlugin(); + await downloadModels(); if (Platform.isAndroid) { await _initModelUsingEntePlugin(); } else { diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index 7482057588..65cb61e5e2 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -120,6 +120,8 @@ class _MachineLearningSettingsPageState final isEnabled = await localSettings.toggleFaceIndexing(); if (isEnabled) { await MLService.instance.init(); + MLService.instance.downloadModels().ignore(); + await SemanticSearchService.instance.init(); unawaited(MLService.instance.runAllML(force: true)); } else {} @@ -136,7 +138,7 @@ class _MachineLearningSettingsPageState height: 12, ), hasEnabled - ? MLService.instance.allModelsLoaded + ? MLService.instance.areModelDownloaded ? const MLStatusWidget() : const ModelLoadingState() : const SizedBox.shrink(),