[mob][photos] Improve magic search speed using simd

laurenspriem
2024-09-03 18:48:16 +02:00
parent 877c660ba9
commit 7097cb335d
3 changed files with 80 additions and 24 deletions
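Summary of the change: the search path used to score every cached ClipEmbedding against the text embedding with a scalar computeCosineSimilarity loop. After this commit the stored CLIP embeddings are loaded as ml_linalg Vectors (via the new EmbeddingVector model) and scored with Vector.dot, which ml_linalg implements with Float32x4 SIMD arithmetic (four float32 lanes per operation). The sketch below is not part of the commit; it is a minimal standalone comparison of the two approaches, with an illustrative scalarDot helper and toy data, and it assumes the CLIP embeddings are L2-normalized so that a plain dot product equals cosine similarity.

import "package:ml_linalg/vector.dart";

// Scalar baseline, roughly what the removed computeCosineSimilarity loop did
// for unit-length vectors: one multiply-add per element.
double scalarDot(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}

void main() {
  // Toy 512-dimensional stand-ins for an image and a text embedding.
  final a = List<double>.generate(512, (i) => (i % 7) / 7);
  final b = List<double>.generate(512, (i) => ((i + 3) % 5) / 5);

  // SIMD path used after this commit: Vector.fromList stores float32 data
  // and dot() is evaluated over Float32x4 lanes.
  final va = Vector.fromList(a);
  final vb = Vector.fromList(b);
  final simd = va.dot(vb);

  final scalar = scalarDot(a, b);
  print("simd: $simd, scalar: $scalar"); // equal up to float32 rounding
}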

@@ -7,16 +7,23 @@ import "package:photos/db/ml/db_fields.dart";
import "package:photos/events/embedding_updated_event.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";
extension ClipDB on MLDataDB {
static const databaseName = "ente.embeddings.db";
Future<List<ClipEmbedding>> getAll() async {
Future<List<ClipEmbedding>> getAllClipEmbeddings() async {
final db = await MLDataDB.instance.asyncDB;
final results = await db.getAll('SELECT * FROM $clipTable');
return _convertToEmbeddings(results);
}
Future<List<EmbeddingVector>> getAllClipVectors() async {
final db = await MLDataDB.instance.asyncDB;
final results = await db.getAll('SELECT * FROM $clipTable');
return _convertToVectors(results);
}
// Get indexed FileIDs
Future<Map<int, int>> clipIndexedFileWithVersion() async {
final db = await MLDataDB.instance.asyncDB;
@@ -83,6 +90,16 @@ extension ClipDB on MLDataDB {
     return embeddings;
   }
 
+  List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
+    final List<EmbeddingVector> embeddings = [];
+    for (final result in results) {
+      final embedding = _getVectorFromRow(result);
+      if (embedding.isEmpty) continue;
+      embeddings.add(embedding);
+    }
+    return embeddings;
+  }
+
   ClipEmbedding _getEmbeddingFromRow(Map<String, dynamic> row) {
     final fileID = row[fileIDColumn] as int;
     final bytes = row[embeddingColumn] as Uint8List;
@@ -91,6 +108,13 @@ extension ClipDB on MLDataDB {
     return ClipEmbedding(fileID: fileID, embedding: list, version: version);
   }
 
+  EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
+    final fileID = row[fileIDColumn] as int;
+    final bytes = row[embeddingColumn] as Uint8List;
+    final list = Float32List.view(bytes.buffer);
+    return EmbeddingVector(fileID: fileID, embedding: list);
+  }
+
   List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
     return [
       embedding.fileID,
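A note on the decoding path in _getVectorFromRow above: the stored embedding blob is reinterpreted as 32-bit floats with Float32List.view(bytes.buffer) rather than parsed element by element, so no copy is made. The snippet below is not from the commit; it is a small round-trip illustration, assuming the blob holds the raw bytes of a Float32List and owns its buffer (a fresh Uint8List read back from the database does).

import "dart:typed_data";

void main() {
  // Hypothetical encode step: persist an embedding as its raw float32 bytes.
  final embedding = Float32List.fromList([0.1, 0.2, 0.3]);
  final Uint8List blob = embedding.buffer.asUint8List();

  // Decode step, mirroring _getVectorFromRow: view the same bytes as float32
  // without copying. Float32List.view reads the underlying buffer from offset
  // 0, so this assumes the blob is not a view into a larger buffer.
  final decoded = Float32List.view(blob.buffer);
  print(decoded); // [0.1, 0.2, 0.3], up to float32 rounding in the printout
}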

@@ -0,0 +1,23 @@
import "dart:convert";
import "package:ml_linalg/vector.dart";
class EmbeddingVector {
final int fileID;
final Vector vector;
bool get isEmpty => vector.isEmpty;
EmbeddingVector({
required this.fileID,
required List<double> embedding,
}) : vector = Vector.fromList(embedding);
static Vector decodeEmbedding(String embedding) {
return Vector.fromList(List<double>.from(jsonDecode(embedding) as List));
}
static String encodeEmbedding(Vector embedding) {
return jsonEncode(embedding.toList());
}
}
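The new EmbeddingVector model replaces ClipEmbedding on the query path: it drops the model-version field and keeps the embedding as an ml_linalg Vector, so the SIMD-friendly representation is built once when rows are loaded instead of on every search. A minimal usage sketch follows; it assumes the file above is importable as package:photos/models/ml/vector.dart (the path used by the imports in the first file) and uses made-up unit vectors.

import "package:ml_linalg/vector.dart";
import "package:photos/models/ml/vector.dart";

void main() {
  // Toy unit vectors standing in for an image embedding and a text embedding.
  final image = EmbeddingVector(fileID: 42, embedding: [0.0, 0.6, 0.8]);
  final Vector text = Vector.fromList([0.0, 0.6, 0.8]);

  // For unit vectors the dot product is the cosine similarity; here ~1.0.
  final similarity = image.vector.dot(text);
  print("file ${image.fileID}: $similarity");
}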

@@ -6,6 +6,7 @@ import "dart:ui" show Image;
import "package:computer/computer.dart";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
@@ -15,9 +16,9 @@ import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/models/file/file.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
@@ -37,7 +38,7 @@ class SemanticSearchService {
   bool _hasInitialized = false;
   bool _textModelIsLoaded = false;
 
-  Future<List<ClipEmbedding>>? _cachedImageEmbeddings;
+  Future<List<EmbeddingVector>>? _cachedImageEmbeddingVectors;
   Future<(String, List<EnteFile>)>? _searchScreenRequest;
   String? _latestPendingQuery;
@@ -53,10 +54,10 @@ class SemanticSearchService {
     // call getClipEmbeddings after 5 seconds
     Future.delayed(const Duration(seconds: 5), () async {
-      await getClipEmbeddings();
+      await getClipVectors();
     });
 
     Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
-      _cachedImageEmbeddings = null;
+      _cachedImageEmbeddingVectors = null;
     });
 
     unawaited(_loadTextModel(delay: true));
@@ -106,10 +107,10 @@ class SemanticSearchService {
_logger.info("Indexes cleared");
}
Future<List<ClipEmbedding>> getClipEmbeddings() async {
_logger.info("Pulling cached embeddings");
_cachedImageEmbeddings ??= MLDataDB.instance.getAll();
return _cachedImageEmbeddings!;
Future<List<EmbeddingVector>> getClipVectors() async {
_logger.info("Pulling cached clip embeddings");
_cachedImageEmbeddingVectors ??= MLDataDB.instance.getAllClipVectors();
return _cachedImageEmbeddingVectors!;
}
Future<List<EnteFile>> getMatchingFiles(
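getClipVectors above memoizes the Future itself rather than the decoded list: the first caller triggers the database load, concurrent callers await the same Future, and the EmbeddingUpdatedEvent listener resets the field to null so the next search reloads fresh vectors. A standalone sketch of that pattern (names here are illustrative, not from the codebase):

// Minimal sketch of the ??= Future-memoization pattern used by getClipVectors;
// _loadFromDb stands in for MLDataDB.instance.getAllClipVectors().
Future<List<int>>? _cached;

Future<List<int>> getValues() {
  _cached ??= _loadFromDb(); // first caller starts the load, later callers share it
  return _cached!;
}

Future<List<int>> _loadFromDb() async {
  await Future<void>.delayed(const Duration(milliseconds: 10)); // fake I/O
  return [1, 2, 3];
}

void invalidate() => _cached = null; // mirrors the EmbeddingUpdatedEvent reset

Future<void> main() async {
  final first = await getValues();
  final second = await getValues(); // served from the cached Future
  print(identical(first, second)); // true: same underlying load
  invalidate();
  print(identical(first, await getValues())); // false: reloaded after reset
}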
@@ -267,7 +268,7 @@ class SemanticSearchService {
     double? minimumSimilarity,
   }) async {
     final startTime = DateTime.now();
-    final embeddings = await getClipEmbeddings();
+    final embeddings = await getClipVectors();
     final List<QueryResult> queryResults = await _computer.compute(
       computeBulkSimilarities,
       param: {
@@ -311,25 +312,33 @@ class SemanticSearchService {
 List<QueryResult> computeBulkSimilarities(Map args) {
   final queryResults = <QueryResult>[];
-  final imageEmbeddings = args["imageEmbeddings"] as List<ClipEmbedding>;
+  final imageEmbeddings = args["imageEmbeddings"] as List<EmbeddingVector>;
   final textEmbedding = args["textEmbedding"] as List<double>;
   final minimumSimilarity = args["minimumSimilarity"] ??
       SemanticSearchService.kMinimumSimilarityThreshold;
-  double bestScore = 0.0;
-  for (final imageEmbedding in imageEmbeddings) {
-    final score = computeCosineSimilarity(
-      imageEmbedding.embedding,
-      textEmbedding,
-    );
-    if (score >= minimumSimilarity) {
-      queryResults.add(QueryResult(imageEmbedding.fileID, score));
-    }
-    if (score > bestScore) {
-      bestScore = score;
-    }
-  }
-  if (kDebugMode && queryResults.isEmpty) {
-    dev.log("No results found for query with best score: $bestScore");
-  }
+  final Vector textVector = Vector.fromList(textEmbedding);
+
+  if (!kDebugMode) {
+    for (final imageEmbedding in imageEmbeddings) {
+      final similarity = imageEmbedding.vector.dot(textVector);
+      if (similarity >= minimumSimilarity) {
+        queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
+      }
+    }
+  } else {
+    double bestScore = 0.0;
+    for (final imageEmbedding in imageEmbeddings) {
+      final similarity = imageEmbedding.vector.dot(textVector);
+      if (similarity >= minimumSimilarity) {
+        queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
+      }
+      if (similarity > bestScore) {
+        bestScore = similarity;
+      }
+    }
+    if (kDebugMode && queryResults.isEmpty) {
+      dev.log("No results found for query with best score: $bestScore");
+    }
+  }
 
   queryResults.sort((first, second) => second.score.compareTo(first.score));