From cf8b9e436504e5286b2f4d62e271542861c92a5b Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 26 Jul 2024 13:41:45 +0200 Subject: [PATCH] [mob][photos] Simplify --- .../clip/clip_text_encoder.dart | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart index 1587a231fe..c033f6caf5 100644 --- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart +++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart @@ -1,17 +1,15 @@ -import "dart:math"; - import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:onnx_dart/onnx_dart.dart"; import "package:onnxruntime/onnxruntime.dart"; -import "package:photos/extensions/stop_watch.dart"; import "package:photos/services/machine_learning/ml_model.dart"; import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart'; import "package:photos/utils/ml_util.dart"; class ClipTextEncoder extends MlModel { // static const _kRemoteBucketModelPath = "clip-text-vit-32-float32-int32.onnx"; // Unquantized model - static const _kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx"; // Quantized model + static const _kRemoteBucketModelPath = + "clip-text-vit-32-uint8.onnx"; // Quantized model static const _kVocabRemotePath = "bpe_simple_vocab_16e6.txt"; // static const kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx"; @@ -56,19 +54,13 @@ class ClipTextEncoder extends MlModel { final session = OrtSession.fromAddress(address); final outputs = session.run(runOptions, inputs); final embedding = (outputs[0]?.value as List>)[0]; - double textNormalization = 0; - for (int i = 0; i < 512; i++) { - textNormalization += embedding[i] * embedding[i]; - } - final double sqrtTextNormalization = sqrt(textNormalization); - for (int i = 0; i < 512; i++) { - embedding[i] = embedding[i] / sqrtTextNormalization; - } + normalizeEmbedding(embedding); return embedding; } - static Future> _runPlatformPluginPredict(Int32List int32list) async { - final w = EnteWatch("ClipTextEncoder._runPlatformPlugin")..start(); + static Future> _runPlatformPluginPredict( + Int32List int32list, + ) async { final OnnxDart plugin = OnnxDart(); final result = await plugin.predictInt( int32list, @@ -76,7 +68,6 @@ class ClipTextEncoder extends MlModel { ); final List embedding = result!.sublist(0, 512); normalizeEmbedding(embedding); - w.stopWithLog("done"); return embedding; } }