From 6adbf176303ec90dd21eb12924be4334d38639eb Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 22 Aug 2024 14:52:09 +0200 Subject: [PATCH] [mob][photos] Only download models on wifi --- .../machine_learning/ml_computer.dart | 9 ++++--- .../services/machine_learning/ml_model.dart | 25 ++++++++++++++++--- .../services/machine_learning/ml_service.dart | 7 ++++++ 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mobile/lib/services/machine_learning/ml_computer.dart b/mobile/lib/services/machine_learning/ml_computer.dart index ed9cc719d4..afbd511c62 100644 --- a/mobile/lib/services/machine_learning/ml_computer.dart +++ b/mobile/lib/services/machine_learning/ml_computer.dart @@ -158,8 +158,8 @@ class MLComputer { } Future> runClipText(String query) async { - await _ensureLoadedClipTextModel(); try { + await _ensureLoadedClipTextModel(); final int clipAddress = ClipTextEncoder.instance.sessionAddress; final textEmbedding = await _runInIsolate( ( @@ -195,9 +195,10 @@ class MLComputer { // Load ClipText model final String modelName = ClipTextEncoder.instance.modelName; - final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath; - final String modelPath = - await RemoteAssetsService.instance.getAssetPath(modelRemotePath); + final String? modelPath = await ClipTextEncoder.instance.downloadModelSafe(); + if (modelPath == null) { + throw Exception("Could not download clip text model, no wifi"); + } final address = await _runInIsolate( ( MLComputerOperation.loadModel, diff --git a/mobile/lib/services/machine_learning/ml_model.dart b/mobile/lib/services/machine_learning/ml_model.dart index daad0ad61a..a3a05e5f14 100644 --- a/mobile/lib/services/machine_learning/ml_model.dart +++ b/mobile/lib/services/machine_learning/ml_model.dart @@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/machine_learning/onnx_env.dart"; import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/network_util.dart"; import "package:synchronized/synchronized.dart"; abstract class MlModel { @@ -32,6 +33,7 @@ abstract class MlModel { bool _isNativePluginInitialized = false; int _nativePluginSessionIndex = -1; + /// WARNING: If [downloadModel] was not first called, this method will download the model first using high bandwidth. Future<(String, String)> getModelNameAndPath() async { return _downloadModelLock.synchronized(() async { final path = @@ -40,12 +42,29 @@ abstract class MlModel { }); } - Future downloadModel([bool forceRefresh = false]) async { + Future downloadModelSafe() async { + if (await RemoteAssetsService.instance.hasAsset(modelRemotePath)) { + return await RemoteAssetsService.instance.getAssetPath(modelRemotePath); + } else { + if (await canUseHighBandwidth()) { + return await downloadModel(); + } else { + logger.warning( + 'Cannot return model path as it is not available locally and high bandwidth is not available.', + ); + return null; + } + } + } + + Future downloadModel([bool forceRefresh = false]) async { return _downloadModelLock.synchronized(() async { if (forceRefresh) { - await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath); + final file = await RemoteAssetsService.instance + .getAssetIfUpdated(modelRemotePath); + return file!.path; } else { - await RemoteAssetsService.instance.getAsset(modelRemotePath); + return await RemoteAssetsService.instance.getAssetPath(modelRemotePath); } }); } diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 7fb523a917..d37e358dd8 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -171,6 +171,8 @@ class MLService { 'stopping indexing because user is not connected to wifi', ); break; + } else { + await _ensureDownloadedModels(); } final futures = >[]; for (final instruction in chunk) { @@ -500,6 +502,11 @@ class MLService { _logger.finest("Models already downloaded"); return; } + final goodInternet = await canUseHighBandwidth(); + if (!goodInternet) { + _logger.info("Cannot download models because user is not connected to wifi"); + return; + } _logger.info('Downloading models'); await Future.wait([ FaceDetectionService.instance.downloadModel(forceRefresh),