diff --git a/mobile/lib/services/machine_learning/ml_computer.dart b/mobile/lib/services/machine_learning/ml_computer.dart
index 00044a4b8a..8c370b31d1 100644
--- a/mobile/lib/services/machine_learning/ml_computer.dart
+++ b/mobile/lib/services/machine_learning/ml_computer.dart
@@ -6,6 +6,7 @@ import 'dart:typed_data' show Uint8List;
 import "package:dart_ui_isolate/dart_ui_isolate.dart";
 import "package:logging/logging.dart";
 import "package:photos/face/model/box.dart";
+import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
 import "package:photos/services/remote_assets_service.dart";
 import "package:photos/utils/image_ml_util.dart";
@@ -13,6 +14,7 @@ import "package:synchronized/synchronized.dart";
 
 enum MLComputerOperation {
   generateFaceThumbnails,
+  loadModel,
   runClipText,
 }
 
@@ -85,6 +87,15 @@ class MLComputerIsolate {
               faceBoxes,
             );
             sendPort.send(List.from(results));
+          case MLComputerOperation.loadModel:
+            final modelName = args['modelName'] as String;
+            final modelPath = args['modelPath'] as String;
+            final int address = await MlModel.loadModel(
+              modelName,
+              modelPath,
+            );
+            sendPort.send(address);
+            break;
           case MLComputerOperation.runClipText:
             final textEmbedding = await ClipTextEncoder.predict(args);
             sendPort.send(List.from(textEmbedding, growable: false));
@@ -147,10 +158,12 @@ class MLComputerIsolate {
 
   Future<List<double>> runClipText(String query) async {
     try {
+      await _ensureLoadedClipTextModel();
       final int clipAddress = ClipTextEncoder.instance.sessionAddress;
-      final String remotePath = ClipTextEncoder.instance.vocabRemotePath;
+      final String tokenizerRemotePath =
+          ClipTextEncoder.instance.vocabRemotePath;
       final String tokenizerVocabPath =
-          await RemoteAssetsService.instance.getAssetPath(remotePath);
+          await RemoteAssetsService.instance.getAssetPath(tokenizerRemotePath);
       final textEmbedding = await _runInIsolate(
         (
           MLComputerOperation.runClipText,
@@ -168,4 +181,27 @@ class MLComputerIsolate {
       rethrow;
     }
   }
+
+  Future<void> _ensureLoadedClipTextModel() async {
+    try {
+      if (ClipTextEncoder.instance.isInitialized) return;
+      final String modelName = ClipTextEncoder.instance.modelName;
+      final String remotePath = ClipTextEncoder.instance.modelRemotePath;
+      final String modelPath =
+          await RemoteAssetsService.instance.getAssetPath(remotePath);
+      final address = await _runInIsolate(
+        (
+          MLComputerOperation.loadModel,
+          {
+            'modelName': modelName,
+            'modelPath': modelPath,
+          },
+        ),
+      ) as int;
+      ClipTextEncoder.instance.storeSessionAddress(address);
+    } catch (e, s) {
+      _logger.severe("Could not load clip text model in MLComputer", e, s);
+      rethrow;
+    }
+  }
 }