[mob] Download models before interence

This commit is contained in:
Neeraj Gupta
2024-07-24 17:28:45 +05:30
parent e6e9948fd8
commit fcd1ec3b05
3 changed files with 37 additions and 2 deletions

View File

@@ -38,6 +38,10 @@ abstract class MlModel {
}
}
Future<void> downloadModel() async {
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
}
Future<void> _loadModelWithEntePlugin(
String modelName,
String modelPath,

View File

@@ -33,7 +33,9 @@ import "package:photos/services/machine_learning/file_ml/remote_fileml_service.d
import 'package:photos/services/machine_learning/ml_exceptions.dart';
import 'package:photos/services/machine_learning/ml_result.dart';
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/ml_util.dart";
import "package:photos/utils/network_util.dart";
@@ -65,11 +67,13 @@ class MLService {
factory MLService() => instance;
final _initModelLock = Lock();
final _downloadModelLock = Lock();
final _functionLock = Lock();
final _initIsolateLock = Lock();
bool _isInitialized = false;
bool _isModelsInitialized = false;
bool areModelDownloaded = false;
bool _isModelsInitUsingEntePlugin = false;
bool _isIsolateSpawned = false;
@@ -525,6 +529,31 @@ class MLService {
return actuallyRanML;
}
Future<void> downloadModels() {
if (areModelDownloaded) {
_logger.finest("Models already downloaded");
return Future.value();
}
if (_downloadModelLock.locked) {
_logger.finest("Download models already in progress");
}
return _downloadModelLock.synchronized(() async {
if (areModelDownloaded) {
return;
}
_logger.info('Downloading models');
await Future.wait([
FaceDetectionService.instance.downloadModel(),
FaceEmbeddingService.instance.downloadModel(),
ClipImageEncoder.instance.downloadModel(),
ClipImageEncoder.instance.downloadModel(),
RemoteAssetsService.instance
.getAsset(ClipTextTokenizer.kVocabRemotePath),
]);
areModelDownloaded = true;
});
}
Future<void> _initModelsUsingFfiBasedPlugin() async {
return _initModelLock.synchronized(() async {
if (_isModelsInitialized) return;
@@ -632,7 +661,7 @@ class MLService {
Future<void> _ensureReadyForInference() async {
await _initIsolate();
await _initModelsUsingFfiBasedPlugin();
await downloadModels();
if (Platform.isAndroid) {
await _initModelUsingEntePlugin();
} else {

View File

@@ -120,6 +120,8 @@ class _MachineLearningSettingsPageState
final isEnabled = await localSettings.toggleFaceIndexing();
if (isEnabled) {
await MLService.instance.init();
MLService.instance.downloadModels().ignore();
await SemanticSearchService.instance.init();
unawaited(MLService.instance.runAllML(force: true));
} else {}
@@ -136,7 +138,7 @@ class _MachineLearningSettingsPageState
height: 12,
),
hasEnabled
? MLService.instance.allModelsLoaded
? MLService.instance.areModelDownloaded
? const MLStatusWidget()
: const ModelLoadingState()
: const SizedBox.shrink(),