[mob][photos] Make sure models are downloaded
This commit is contained in:
@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/machine_learning/onnx_env.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
abstract class MlModel {
|
||||
static final Logger isolateLogger = Logger("MlModelInIsolate");
|
||||
@@ -16,6 +17,8 @@ abstract class MlModel {
|
||||
|
||||
String get modelName;
|
||||
|
||||
final _downloadModelLock = Lock();
|
||||
|
||||
static final bool usePlatformPlugin = Platform.isAndroid;
|
||||
|
||||
bool get isInitialized =>
|
||||
@@ -32,9 +35,21 @@ abstract class MlModel {
|
||||
int _nativePluginSessionIndex = -1;
|
||||
|
||||
Future<(String, String)> getModelNameAndPath() async {
|
||||
final path =
|
||||
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
return (modelName, path);
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
final path =
|
||||
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
return (modelName, path);
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> downloadModel([bool forceRefresh = false]) async {
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
if (forceRefresh) {
|
||||
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
|
||||
} else {
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void storeSessionAddress(int address) {
|
||||
|
||||
@@ -43,8 +43,10 @@ class MLService {
|
||||
factory MLService() => instance;
|
||||
|
||||
final _initModelLock = Lock();
|
||||
final _downloadModelLock = Lock();
|
||||
|
||||
bool _isInitialized = false;
|
||||
bool areModelsDownloaded = false;
|
||||
|
||||
late String client;
|
||||
|
||||
@@ -52,8 +54,6 @@ class MLService {
|
||||
|
||||
bool get showClusteringIsHappening => _showClusteringIsHappening;
|
||||
|
||||
bool modelsAreLoading = false;
|
||||
|
||||
bool debugIndexingDisabled = false;
|
||||
bool _showClusteringIsHappening = false;
|
||||
bool _mlControllerStatus = false;
|
||||
@@ -64,7 +64,7 @@ class MLService {
|
||||
static const _kForceClusteringFaceCount = 8000;
|
||||
|
||||
/// Only call this function once at app startup, after that you can directly call [runAllML]
|
||||
Future<void> init() async {
|
||||
Future<void> init({bool firstTime = false}) async {
|
||||
if (localSettings.isFaceIndexingEnabled == false || _isInitialized) {
|
||||
return;
|
||||
}
|
||||
@@ -78,6 +78,9 @@ class MLService {
|
||||
// Activate FaceRecognitionService
|
||||
await FaceRecognitionService.instance.init();
|
||||
|
||||
// Download models if not already downloaded
|
||||
unawaited(_ensureDownloadedModels(firstTime));
|
||||
|
||||
// Listen on MachineLearningController
|
||||
Bus.instance.on<MachineLearningControlEvent>().listen((event) {
|
||||
if (localSettings.isFaceIndexingEnabled == false) {
|
||||
@@ -509,6 +512,25 @@ class MLService {
|
||||
return actuallyRanML;
|
||||
}
|
||||
|
||||
Future<void> _ensureDownloadedModels([bool forceRefresh = false]) async {
|
||||
if (_downloadModelLock.locked) {
|
||||
_logger.finest("Download models already in progress");
|
||||
}
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
if (areModelsDownloaded) {
|
||||
_logger.finest("Models already downloaded");
|
||||
return;
|
||||
}
|
||||
_logger.info('Downloading models');
|
||||
await Future.wait([
|
||||
FaceDetectionService.instance.downloadModel(forceRefresh),
|
||||
FaceEmbeddingService.instance.downloadModel(forceRefresh),
|
||||
ClipImageEncoder.instance.downloadModel(forceRefresh),
|
||||
]);
|
||||
areModelsDownloaded = true;
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> _ensureLoadedModels(FileMLInstruction instruction) async {
|
||||
return _initModelLock.synchronized(() async {
|
||||
final faceDetectionLoaded = FaceDetectionService.instance.isInitialized;
|
||||
@@ -522,7 +544,6 @@ class MLService {
|
||||
return;
|
||||
}
|
||||
|
||||
modelsAreLoading = true;
|
||||
_logger.info(
|
||||
'Loading models. faces: $shouldLoadFaces, clip: $shouldLoadClip',
|
||||
);
|
||||
@@ -530,7 +551,6 @@ class MLService {
|
||||
.loadModels(loadFaces: shouldLoadFaces, loadClip: shouldLoadClip);
|
||||
_logger.info('Models loaded');
|
||||
_logStatus();
|
||||
modelsAreLoading = false;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -119,8 +119,7 @@ class _MachineLearningSettingsPageState
|
||||
onChanged: () async {
|
||||
final isEnabled = await localSettings.toggleFaceIndexing();
|
||||
if (isEnabled) {
|
||||
await MLService.instance.init();
|
||||
MLService.instance.downloadModels().ignore();
|
||||
await MLService.instance.init(firstTime: true);
|
||||
|
||||
await SemanticSearchService.instance.init();
|
||||
unawaited(MLService.instance.runAllML(force: true));
|
||||
@@ -138,9 +137,9 @@ class _MachineLearningSettingsPageState
|
||||
height: 12,
|
||||
),
|
||||
hasEnabled
|
||||
? MLService.instance.modelsAreLoading
|
||||
? const ModelLoadingState()
|
||||
: const MLStatusWidget()
|
||||
? MLService.instance.areModelsDownloaded
|
||||
? const MLStatusWidget()
|
||||
: const ModelLoadingState()
|
||||
: const SizedBox.shrink(),
|
||||
],
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user