[mob][photos] Make sure ClipText is loaded

This commit is contained in:
laurenspriem
2024-07-25 22:32:16 +02:00
parent e65a36d5c7
commit 1e695ffd84

View File

@@ -6,6 +6,7 @@ import 'dart:typed_data' show Uint8List;
import "package:dart_ui_isolate/dart_ui_isolate.dart";
import "package:logging/logging.dart";
import "package:photos/face/model/box.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/image_ml_util.dart";
@@ -13,6 +14,7 @@ import "package:synchronized/synchronized.dart";
enum MLComputerOperation {
generateFaceThumbnails,
loadModel,
runClipText,
}
@@ -85,6 +87,15 @@ class MLComputerIsolate {
faceBoxes,
);
sendPort.send(List.from(results));
case MLComputerOperation.loadModel:
final modelName = args['modelName'] as String;
final modelPath = args['modelPath'] as String;
final int address = await MlModel.loadModel(
modelName,
modelPath,
);
sendPort.send(address);
break;
case MLComputerOperation.runClipText:
final textEmbedding = await ClipTextEncoder.predict(args);
sendPort.send(List.from(textEmbedding, growable: false));
@@ -147,10 +158,12 @@ class MLComputerIsolate {
Future<List<double>> runClipText(String query) async {
try {
await _ensureLoadedClipTextModel();
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
final String remotePath = ClipTextEncoder.instance.vocabRemotePath;
final String tokenizerRemotePath =
ClipTextEncoder.instance.vocabRemotePath;
final String tokenizerVocabPath =
await RemoteAssetsService.instance.getAssetPath(remotePath);
await RemoteAssetsService.instance.getAssetPath(tokenizerRemotePath);
final textEmbedding = await _runInIsolate(
(
MLComputerOperation.runClipText,
@@ -168,4 +181,27 @@ class MLComputerIsolate {
rethrow;
}
}
Future<void> _ensureLoadedClipTextModel() async {
try {
if (ClipTextEncoder.instance.isInitialized) return;
final String modelName = ClipTextEncoder.instance.modelName;
final String remotePath = ClipTextEncoder.instance.modelRemotePath;
final String modelPath =
await RemoteAssetsService.instance.getAssetPath(remotePath);
final address = await _runInIsolate(
(
MLComputerOperation.loadModel,
{
'modelName': modelName,
'modelPath': modelPath,
},
),
) as int;
ClipTextEncoder.instance.storeSessionAddress(address);
} catch (e, s) {
_logger.severe("Could not load clip text model in MLComputer", e, s);
rethrow;
}
}
}