[mob][photos] Only download models on wifi

This commit is contained in:
laurenspriem
2024-08-22 14:52:09 +02:00
parent 2ffad1f2a2
commit 6adbf17630
3 changed files with 34 additions and 7 deletions

View File

@@ -158,8 +158,8 @@ class MLComputer {
}
Future<List<double>> runClipText(String query) async {
await _ensureLoadedClipTextModel();
try {
await _ensureLoadedClipTextModel();
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
final textEmbedding = await _runInIsolate(
(
@@ -195,9 +195,10 @@ class MLComputer {
// Load ClipText model
final String modelName = ClipTextEncoder.instance.modelName;
final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath;
final String modelPath =
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
final String? modelPath = await ClipTextEncoder.instance.downloadModelSafe();
if (modelPath == null) {
throw Exception("Could not download clip text model, no wifi");
}
final address = await _runInIsolate(
(
MLComputerOperation.loadModel,

View File

@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/machine_learning/onnx_env.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/network_util.dart";
import "package:synchronized/synchronized.dart";
abstract class MlModel {
@@ -32,6 +33,7 @@ abstract class MlModel {
bool _isNativePluginInitialized = false;
int _nativePluginSessionIndex = -1;
/// WARNING: If [downloadModel] was not first called, this method will download the model first using high bandwidth.
Future<(String, String)> getModelNameAndPath() async {
return _downloadModelLock.synchronized(() async {
final path =
@@ -40,12 +42,29 @@ abstract class MlModel {
});
}
Future<void> downloadModel([bool forceRefresh = false]) async {
Future<String?> downloadModelSafe() async {
if (await RemoteAssetsService.instance.hasAsset(modelRemotePath)) {
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
} else {
if (await canUseHighBandwidth()) {
return await downloadModel();
} else {
logger.warning(
'Cannot return model path as it is not available locally and high bandwidth is not available.',
);
return null;
}
}
}
Future<String> downloadModel([bool forceRefresh = false]) async {
return _downloadModelLock.synchronized(() async {
if (forceRefresh) {
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
final file = await RemoteAssetsService.instance
.getAssetIfUpdated(modelRemotePath);
return file!.path;
} else {
await RemoteAssetsService.instance.getAsset(modelRemotePath);
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
}
});
}

View File

@@ -171,6 +171,8 @@ class MLService {
'stopping indexing because user is not connected to wifi',
);
break;
} else {
await _ensureDownloadedModels();
}
final futures = <Future<bool>>[];
for (final instruction in chunk) {
@@ -500,6 +502,11 @@ class MLService {
_logger.finest("Models already downloaded");
return;
}
final goodInternet = await canUseHighBandwidth();
if (!goodInternet) {
_logger.info("Cannot download models because user is not connected to wifi");
return;
}
_logger.info('Downloading models');
await Future.wait([
FaceDetectionService.instance.downloadModel(forceRefresh),