[mob][photos] Improve magic search speed using simd
This commit is contained in:
@@ -7,16 +7,23 @@ import "package:photos/db/ml/db_fields.dart";
|
||||
import "package:photos/events/embedding_updated_event.dart";
|
||||
import "package:photos/models/ml/clip.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:photos/models/ml/vector.dart";
|
||||
|
||||
extension ClipDB on MLDataDB {
|
||||
static const databaseName = "ente.embeddings.db";
|
||||
|
||||
Future<List<ClipEmbedding>> getAll() async {
|
||||
Future<List<ClipEmbedding>> getAllClipEmbeddings() async {
|
||||
final db = await MLDataDB.instance.asyncDB;
|
||||
final results = await db.getAll('SELECT * FROM $clipTable');
|
||||
return _convertToEmbeddings(results);
|
||||
}
|
||||
|
||||
Future<List<EmbeddingVector>> getAllClipVectors() async {
|
||||
final db = await MLDataDB.instance.asyncDB;
|
||||
final results = await db.getAll('SELECT * FROM $clipTable');
|
||||
return _convertToVectors(results);
|
||||
}
|
||||
|
||||
// Get indexed FileIDs
|
||||
Future<Map<int, int>> clipIndexedFileWithVersion() async {
|
||||
final db = await MLDataDB.instance.asyncDB;
|
||||
@@ -83,6 +90,16 @@ extension ClipDB on MLDataDB {
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
|
||||
final List<EmbeddingVector> embeddings = [];
|
||||
for (final result in results) {
|
||||
final embedding = _getVectorFromRow(result);
|
||||
if (embedding.isEmpty) continue;
|
||||
embeddings.add(embedding);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
ClipEmbedding _getEmbeddingFromRow(Map<String, dynamic> row) {
|
||||
final fileID = row[fileIDColumn] as int;
|
||||
final bytes = row[embeddingColumn] as Uint8List;
|
||||
@@ -91,6 +108,13 @@ extension ClipDB on MLDataDB {
|
||||
return ClipEmbedding(fileID: fileID, embedding: list, version: version);
|
||||
}
|
||||
|
||||
EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
|
||||
final fileID = row[fileIDColumn] as int;
|
||||
final bytes = row[embeddingColumn] as Uint8List;
|
||||
final list = Float32List.view(bytes.buffer);
|
||||
return EmbeddingVector(fileID: fileID, embedding: list);
|
||||
}
|
||||
|
||||
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
|
||||
return [
|
||||
embedding.fileID,
|
||||
|
||||
23
mobile/lib/models/ml/vector.dart
Normal file
23
mobile/lib/models/ml/vector.dart
Normal file
@@ -0,0 +1,23 @@
|
||||
import "dart:convert";
|
||||
|
||||
import "package:ml_linalg/vector.dart";
|
||||
|
||||
class EmbeddingVector {
|
||||
final int fileID;
|
||||
final Vector vector;
|
||||
|
||||
bool get isEmpty => vector.isEmpty;
|
||||
|
||||
EmbeddingVector({
|
||||
required this.fileID,
|
||||
required List<double> embedding,
|
||||
}) : vector = Vector.fromList(embedding);
|
||||
|
||||
static Vector decodeEmbedding(String embedding) {
|
||||
return Vector.fromList(List<double>.from(jsonDecode(embedding) as List));
|
||||
}
|
||||
|
||||
static String encodeEmbedding(Vector embedding) {
|
||||
return jsonEncode(embedding.toList());
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import "dart:ui" show Image;
|
||||
import "package:computer/computer.dart";
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/vector.dart";
|
||||
import "package:photos/core/cache/lru_map.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
@@ -15,9 +16,9 @@ import 'package:photos/events/embedding_updated_event.dart';
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/ml/clip.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:photos/models/ml/vector.dart";
|
||||
import "package:photos/service_locator.dart";
|
||||
import "package:photos/services/collections_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
||||
import "package:photos/services/machine_learning/ml_computer.dart";
|
||||
import "package:photos/services/machine_learning/ml_result.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
|
||||
@@ -37,7 +38,7 @@ class SemanticSearchService {
|
||||
bool _hasInitialized = false;
|
||||
bool _textModelIsLoaded = false;
|
||||
|
||||
Future<List<ClipEmbedding>>? _cachedImageEmbeddings;
|
||||
Future<List<EmbeddingVector>>? _cachedImageEmbeddingVectors;
|
||||
Future<(String, List<EnteFile>)>? _searchScreenRequest;
|
||||
String? _latestPendingQuery;
|
||||
|
||||
@@ -53,10 +54,10 @@ class SemanticSearchService {
|
||||
|
||||
// call getClipEmbeddings after 5 seconds
|
||||
Future.delayed(const Duration(seconds: 5), () async {
|
||||
await getClipEmbeddings();
|
||||
await getClipVectors();
|
||||
});
|
||||
Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
|
||||
_cachedImageEmbeddings = null;
|
||||
_cachedImageEmbeddingVectors = null;
|
||||
});
|
||||
|
||||
unawaited(_loadTextModel(delay: true));
|
||||
@@ -106,10 +107,10 @@ class SemanticSearchService {
|
||||
_logger.info("Indexes cleared");
|
||||
}
|
||||
|
||||
Future<List<ClipEmbedding>> getClipEmbeddings() async {
|
||||
_logger.info("Pulling cached embeddings");
|
||||
_cachedImageEmbeddings ??= MLDataDB.instance.getAll();
|
||||
return _cachedImageEmbeddings!;
|
||||
Future<List<EmbeddingVector>> getClipVectors() async {
|
||||
_logger.info("Pulling cached clip embeddings");
|
||||
_cachedImageEmbeddingVectors ??= MLDataDB.instance.getAllClipVectors();
|
||||
return _cachedImageEmbeddingVectors!;
|
||||
}
|
||||
|
||||
Future<List<EnteFile>> getMatchingFiles(
|
||||
@@ -267,7 +268,7 @@ class SemanticSearchService {
|
||||
double? minimumSimilarity,
|
||||
}) async {
|
||||
final startTime = DateTime.now();
|
||||
final embeddings = await getClipEmbeddings();
|
||||
final embeddings = await getClipVectors();
|
||||
final List<QueryResult> queryResults = await _computer.compute(
|
||||
computeBulkSimilarities,
|
||||
param: {
|
||||
@@ -311,25 +312,33 @@ class SemanticSearchService {
|
||||
|
||||
List<QueryResult> computeBulkSimilarities(Map args) {
|
||||
final queryResults = <QueryResult>[];
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<ClipEmbedding>;
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<EmbeddingVector>;
|
||||
final textEmbedding = args["textEmbedding"] as List<double>;
|
||||
final minimumSimilarity = args["minimumSimilarity"] ??
|
||||
SemanticSearchService.kMinimumSimilarityThreshold;
|
||||
double bestScore = 0.0;
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final score = computeCosineSimilarity(
|
||||
imageEmbedding.embedding,
|
||||
textEmbedding,
|
||||
);
|
||||
if (score >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, score));
|
||||
|
||||
final Vector textVector = Vector.fromList(textEmbedding);
|
||||
if (!kDebugMode) {
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final similarity = imageEmbedding.vector.dot(textVector);
|
||||
if (similarity >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
|
||||
}
|
||||
}
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
} else {
|
||||
double bestScore = 0.0;
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final similarity = imageEmbedding.vector.dot(textVector);
|
||||
if (similarity >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
|
||||
}
|
||||
if (similarity > bestScore) {
|
||||
bestScore = similarity;
|
||||
}
|
||||
}
|
||||
if (kDebugMode && queryResults.isEmpty) {
|
||||
dev.log("No results found for query with best score: $bestScore");
|
||||
}
|
||||
}
|
||||
if (kDebugMode && queryResults.isEmpty) {
|
||||
dev.log("No results found for query with best score: $bestScore");
|
||||
}
|
||||
|
||||
queryResults.sort((first, second) => second.score.compareTo(first.score));
|
||||
|
||||
Reference in New Issue
Block a user