[mob][photos] Simplify

This commit is contained in:
laurenspriem
2024-07-26 13:41:45 +02:00
parent 627ec5f2a6
commit cf8b9e4365

View File

@@ -1,17 +1,15 @@
import "dart:math";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:onnx_dart/onnx_dart.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart';
import "package:photos/utils/ml_util.dart";
class ClipTextEncoder extends MlModel {
// static const _kRemoteBucketModelPath = "clip-text-vit-32-float32-int32.onnx"; // Unquantized model
static const _kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx"; // Quantized model
static const _kRemoteBucketModelPath =
"clip-text-vit-32-uint8.onnx"; // Quantized model
static const _kVocabRemotePath = "bpe_simple_vocab_16e6.txt";
// static const kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx";
@@ -56,19 +54,13 @@ class ClipTextEncoder extends MlModel {
final session = OrtSession.fromAddress(address);
final outputs = session.run(runOptions, inputs);
final embedding = (outputs[0]?.value as List<List<double>>)[0];
double textNormalization = 0;
for (int i = 0; i < 512; i++) {
textNormalization += embedding[i] * embedding[i];
}
final double sqrtTextNormalization = sqrt(textNormalization);
for (int i = 0; i < 512; i++) {
embedding[i] = embedding[i] / sqrtTextNormalization;
}
normalizeEmbedding(embedding);
return embedding;
}
static Future<List<double>> _runPlatformPluginPredict(Int32List int32list) async {
final w = EnteWatch("ClipTextEncoder._runPlatformPlugin")..start();
static Future<List<double>> _runPlatformPluginPredict(
Int32List int32list,
) async {
final OnnxDart plugin = OnnxDart();
final result = await plugin.predictInt(
int32list,
@@ -76,7 +68,6 @@ class ClipTextEncoder extends MlModel {
);
final List<double> embedding = result!.sublist(0, 512);
normalizeEmbedding(embedding);
w.stopWithLog("done");
return embedding;
}
}