[mob][photos] Simplify
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
import "dart:math";
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/extensions/stop_watch.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart';
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
class ClipTextEncoder extends MlModel {
|
||||
// static const _kRemoteBucketModelPath = "clip-text-vit-32-float32-int32.onnx"; // Unquantized model
|
||||
static const _kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx"; // Quantized model
|
||||
static const _kRemoteBucketModelPath =
|
||||
"clip-text-vit-32-uint8.onnx"; // Quantized model
|
||||
static const _kVocabRemotePath = "bpe_simple_vocab_16e6.txt";
|
||||
|
||||
// static const kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx";
|
||||
@@ -56,19 +54,13 @@ class ClipTextEncoder extends MlModel {
|
||||
final session = OrtSession.fromAddress(address);
|
||||
final outputs = session.run(runOptions, inputs);
|
||||
final embedding = (outputs[0]?.value as List<List<double>>)[0];
|
||||
double textNormalization = 0;
|
||||
for (int i = 0; i < 512; i++) {
|
||||
textNormalization += embedding[i] * embedding[i];
|
||||
}
|
||||
final double sqrtTextNormalization = sqrt(textNormalization);
|
||||
for (int i = 0; i < 512; i++) {
|
||||
embedding[i] = embedding[i] / sqrtTextNormalization;
|
||||
}
|
||||
normalizeEmbedding(embedding);
|
||||
return embedding;
|
||||
}
|
||||
|
||||
static Future<List<double>> _runPlatformPluginPredict(Int32List int32list) async {
|
||||
final w = EnteWatch("ClipTextEncoder._runPlatformPlugin")..start();
|
||||
static Future<List<double>> _runPlatformPluginPredict(
|
||||
Int32List int32list,
|
||||
) async {
|
||||
final OnnxDart plugin = OnnxDart();
|
||||
final result = await plugin.predictInt(
|
||||
int32list,
|
||||
@@ -76,7 +68,6 @@ class ClipTextEncoder extends MlModel {
|
||||
);
|
||||
final List<double> embedding = result!.sublist(0, 512);
|
||||
normalizeEmbedding(embedding);
|
||||
w.stopWithLog("done");
|
||||
return embedding;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user