[mob][photos] Only download models on wifi
This commit is contained in:
@@ -158,8 +158,8 @@ class MLComputer {
|
||||
}
|
||||
|
||||
Future<List<double>> runClipText(String query) async {
|
||||
await _ensureLoadedClipTextModel();
|
||||
try {
|
||||
await _ensureLoadedClipTextModel();
|
||||
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
|
||||
final textEmbedding = await _runInIsolate(
|
||||
(
|
||||
@@ -195,9 +195,10 @@ class MLComputer {
|
||||
|
||||
// Load ClipText model
|
||||
final String modelName = ClipTextEncoder.instance.modelName;
|
||||
final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath;
|
||||
final String modelPath =
|
||||
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
final String? modelPath = await ClipTextEncoder.instance.downloadModelSafe();
|
||||
if (modelPath == null) {
|
||||
throw Exception("Could not download clip text model, no wifi");
|
||||
}
|
||||
final address = await _runInIsolate(
|
||||
(
|
||||
MLComputerOperation.loadModel,
|
||||
|
||||
@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/machine_learning/onnx_env.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/utils/network_util.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
abstract class MlModel {
|
||||
@@ -32,6 +33,7 @@ abstract class MlModel {
|
||||
bool _isNativePluginInitialized = false;
|
||||
int _nativePluginSessionIndex = -1;
|
||||
|
||||
/// WARNING: If [downloadModel] was not first called, this method will download the model first using high bandwidth.
|
||||
Future<(String, String)> getModelNameAndPath() async {
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
final path =
|
||||
@@ -40,12 +42,29 @@ abstract class MlModel {
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> downloadModel([bool forceRefresh = false]) async {
|
||||
Future<String?> downloadModelSafe() async {
|
||||
if (await RemoteAssetsService.instance.hasAsset(modelRemotePath)) {
|
||||
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
} else {
|
||||
if (await canUseHighBandwidth()) {
|
||||
return await downloadModel();
|
||||
} else {
|
||||
logger.warning(
|
||||
'Cannot return model path as it is not available locally and high bandwidth is not available.',
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<String> downloadModel([bool forceRefresh = false]) async {
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
if (forceRefresh) {
|
||||
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
|
||||
final file = await RemoteAssetsService.instance
|
||||
.getAssetIfUpdated(modelRemotePath);
|
||||
return file!.path;
|
||||
} else {
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -171,6 +171,8 @@ class MLService {
|
||||
'stopping indexing because user is not connected to wifi',
|
||||
);
|
||||
break;
|
||||
} else {
|
||||
await _ensureDownloadedModels();
|
||||
}
|
||||
final futures = <Future<bool>>[];
|
||||
for (final instruction in chunk) {
|
||||
@@ -500,6 +502,11 @@ class MLService {
|
||||
_logger.finest("Models already downloaded");
|
||||
return;
|
||||
}
|
||||
final goodInternet = await canUseHighBandwidth();
|
||||
if (!goodInternet) {
|
||||
_logger.info("Cannot download models because user is not connected to wifi");
|
||||
return;
|
||||
}
|
||||
_logger.info('Downloading models');
|
||||
await Future.wait([
|
||||
FaceDetectionService.instance.downloadModel(forceRefresh),
|
||||
|
||||
Reference in New Issue
Block a user