[mob][photos] Make sure models are downloaded

This commit is contained in:
laurenspriem
2024-07-26 12:59:18 +02:00
parent 739235836e
commit 1d0cf3875f
3 changed files with 47 additions and 13 deletions

View File

@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/machine_learning/onnx_env.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:synchronized/synchronized.dart";
abstract class MlModel {
static final Logger isolateLogger = Logger("MlModelInIsolate");
@@ -16,6 +17,8 @@ abstract class MlModel {
String get modelName;
final _downloadModelLock = Lock();
static final bool usePlatformPlugin = Platform.isAndroid;
bool get isInitialized =>
@@ -32,9 +35,21 @@ abstract class MlModel {
int _nativePluginSessionIndex = -1;
Future<(String, String)> getModelNameAndPath() async {
final path =
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
return (modelName, path);
return _downloadModelLock.synchronized(() async {
final path =
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
return (modelName, path);
});
}
Future<void> downloadModel([bool forceRefresh = false]) async {
return _downloadModelLock.synchronized(() async {
if (forceRefresh) {
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
} else {
await RemoteAssetsService.instance.getAsset(modelRemotePath);
}
});
}
void storeSessionAddress(int address) {

View File

@@ -43,8 +43,10 @@ class MLService {
factory MLService() => instance;
final _initModelLock = Lock();
final _downloadModelLock = Lock();
bool _isInitialized = false;
bool areModelsDownloaded = false;
late String client;
@@ -52,8 +54,6 @@ class MLService {
bool get showClusteringIsHappening => _showClusteringIsHappening;
bool modelsAreLoading = false;
bool debugIndexingDisabled = false;
bool _showClusteringIsHappening = false;
bool _mlControllerStatus = false;
@@ -64,7 +64,7 @@ class MLService {
static const _kForceClusteringFaceCount = 8000;
/// Only call this function once at app startup, after that you can directly call [runAllML]
Future<void> init() async {
Future<void> init({bool firstTime = false}) async {
if (localSettings.isFaceIndexingEnabled == false || _isInitialized) {
return;
}
@@ -78,6 +78,9 @@ class MLService {
// Activate FaceRecognitionService
await FaceRecognitionService.instance.init();
// Download models if not already downloaded
unawaited(_ensureDownloadedModels(firstTime));
// Listen on MachineLearningController
Bus.instance.on<MachineLearningControlEvent>().listen((event) {
if (localSettings.isFaceIndexingEnabled == false) {
@@ -509,6 +512,25 @@ class MLService {
return actuallyRanML;
}
Future<void> _ensureDownloadedModels([bool forceRefresh = false]) async {
if (_downloadModelLock.locked) {
_logger.finest("Download models already in progress");
}
return _downloadModelLock.synchronized(() async {
if (areModelsDownloaded) {
_logger.finest("Models already downloaded");
return;
}
_logger.info('Downloading models');
await Future.wait([
FaceDetectionService.instance.downloadModel(forceRefresh),
FaceEmbeddingService.instance.downloadModel(forceRefresh),
ClipImageEncoder.instance.downloadModel(forceRefresh),
]);
areModelsDownloaded = true;
});
}
Future<void> _ensureLoadedModels(FileMLInstruction instruction) async {
return _initModelLock.synchronized(() async {
final faceDetectionLoaded = FaceDetectionService.instance.isInitialized;
@@ -522,7 +544,6 @@ class MLService {
return;
}
modelsAreLoading = true;
_logger.info(
'Loading models. faces: $shouldLoadFaces, clip: $shouldLoadClip',
);
@@ -530,7 +551,6 @@ class MLService {
.loadModels(loadFaces: shouldLoadFaces, loadClip: shouldLoadClip);
_logger.info('Models loaded');
_logStatus();
modelsAreLoading = false;
});
}

View File

@@ -119,8 +119,7 @@ class _MachineLearningSettingsPageState
onChanged: () async {
final isEnabled = await localSettings.toggleFaceIndexing();
if (isEnabled) {
await MLService.instance.init();
MLService.instance.downloadModels().ignore();
await MLService.instance.init(firstTime: true);
await SemanticSearchService.instance.init();
unawaited(MLService.instance.runAllML(force: true));
@@ -138,9 +137,9 @@ class _MachineLearningSettingsPageState
height: 12,
),
hasEnabled
? MLService.instance.modelsAreLoading
? const ModelLoadingState()
: const MLStatusWidget()
? MLService.instance.areModelsDownloaded
? const MLStatusWidget()
: const ModelLoadingState()
: const SizedBox.shrink(),
],
);