[mob][photos] Image embedding caching perf (#5459)

## Description

Reworked image-embedding caching so that embeddings are cached only inside
MLComputer, and only for a limited time while the user is actively using
them.

## Tests

Tested in debug mode on my Pixel phone.
This commit is contained in:
Laurens Priem
2025-03-28 14:13:42 +05:30
committed by GitHub
10 changed files with 269 additions and 154 deletions

View File

@@ -1166,8 +1166,21 @@ class MLDataDB with SqlDbBase implements IMLDataDB<int> {
/// Reads every CLIP embedding row from [clipTable] and converts each row
/// to an [EmbeddingVector].
///
/// Only the file-ID and embedding-blob columns are selected so unused
/// columns are never deserialised. Empty vectors are skipped.
Future<List<EmbeddingVector>> getAllClipVectors() async {
  Logger("ClipDB").info("reading all embeddings from DB");
  final db = await instance.asyncDB;
  // Select only the two columns this method reads; `SELECT *` would pull
  // extra columns for no benefit.
  final results = await db
      .getAll('SELECT $fileIDColumn, $embeddingColumn FROM $clipTable');
  final List<EmbeddingVector> embeddings = [];
  for (final result in results) {
    // View (not copy) of the stored blob as 32-bit floats.
    final embedding = EmbeddingVector(
      fileID: result[fileIDColumn],
      embedding: Float32List.view(result[embeddingColumn].buffer),
    );
    // Empty vectors carry no signal; drop them so similarity math never
    // sees a zero-length embedding.
    if (embedding.isEmpty) continue;
    embeddings.add(embedding);
  }
  return embeddings;
}
// Get indexed FileIDs
@@ -1229,23 +1242,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB<int> {
Bus.instance.fire(EmbeddingUpdatedEvent());
}
/// Maps raw DB [results] rows to [EmbeddingVector]s, dropping empty vectors.
List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
  return results
      .map(_getVectorFromRow)
      .where((vector) => !vector.isEmpty)
      .toList();
}
/// Builds an [EmbeddingVector] from a single [clipTable] row.
///
/// The stored embedding blob is reinterpreted as a [Float32List] view over
/// the same bytes, so no copy of the vector data is made.
EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
final fileID = row[fileIDColumn] as int;
final bytes = row[embeddingColumn] as Uint8List;
// View (not copy) of the raw bytes as 32-bit floats.
final list = Float32List.view(bytes.buffer);
return EmbeddingVector(fileID: fileID, embedding: list);
}
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
return [
embedding.fileID,

View File

@@ -5,13 +5,17 @@ import "package:ml_linalg/linalg.dart";
import "package:photos/models/ml/face/box.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/ml_constants.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart";
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/ml_util.dart";
final Map<String, dynamic> _isolateCache = {};
enum IsolateOperation {
/// [MLIndexingIsolate]
analyzeImage,
@@ -35,10 +39,15 @@ enum IsolateOperation {
runClipText,
/// [MLComputer]
compareEmbeddings,
computeBulkSimilarities,
/// [FaceClusteringService]
linearIncrementalClustering,
/// Cache operations
setIsolateCache,
clearIsolateCache,
clearAllIsolateCache,
}
/// WARNING: Only return primitives unless you know the method is only going
@@ -121,19 +130,31 @@ Future<dynamic> isolateFunction(
return List<double>.from(textEmbedding, growable: false);
/// MLComputer
case IsolateOperation.compareEmbeddings:
final List<EmbeddingVector> embeddings =
(args['embeddings'] as List<String>)
.map((jsonString) => EmbeddingVector.fromJsonString(jsonString))
.toList();
final otherEmbedding =
Vector.fromList(args['otherEmbedding'] as List<double>);
final Map<int, double> result = {};
for (final embedding in embeddings) {
final double similarity = embedding.vector.dot(otherEmbedding);
result[embedding.fileID] = similarity;
case IsolateOperation.computeBulkSimilarities:
final imageEmbeddings =
_isolateCache[imageEmbeddingsKey] as List<EmbeddingVector>;
final textEmbedding =
args["textQueryToEmbeddingMap"] as Map<String, List<double>>;
final minimumSimilarityMap =
args["minimumSimilarityMap"] as Map<String, double>;
final result = <String, List<QueryResult>>{};
for (final MapEntry<String, List<double>> entry
in textEmbedding.entries) {
final query = entry.key;
final textVector = Vector.fromList(entry.value);
final minimumSimilarity = minimumSimilarityMap[query]!;
final queryResults = <QueryResult>[];
for (final imageEmbedding in imageEmbeddings) {
final similarity = imageEmbedding.vector.dot(textVector);
if (similarity >= minimumSimilarity) {
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
}
}
queryResults
.sort((first, second) => second.score.compareTo(first.score));
result[query] = queryResults;
}
return Map<int, double>.from(result);
return result;
/// Cases for MLComputer end here
@@ -145,5 +166,27 @@ Future<dynamic> isolateFunction(
return result;
/// Cases for FaceClusteringService end here
/// Cases for Caching start here
/// Caching
case IsolateOperation.setIsolateCache:
final key = args['key'] as String;
final value = args['value'];
_isolateCache[key] = value;
return true;
/// Caching
case IsolateOperation.clearIsolateCache:
final key = args['key'] as String;
_isolateCache.remove(key);
return true;
/// Caching
case IsolateOperation.clearAllIsolateCache:
_isolateCache.clear();
return true;
/// Cases for Caching stop here
}
}

View File

@@ -3,6 +3,7 @@ import 'dart:isolate';
import "package:dart_ui_isolate/dart_ui_isolate.dart";
import "package:flutter/foundation.dart" show kDebugMode;
import "package:flutter/services.dart";
import "package:logging/logging.dart";
import "package:photos/core/error-reporting/isolate_logging.dart";
import "package:photos/models/base/id.dart";
@@ -13,7 +14,7 @@ abstract class SuperIsolate {
Logger get logger;
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(seconds: 120);
final Duration _inactivityDuration = const Duration(seconds: 60);
int _activeTasks = 0;
final _initIsolateLock = Lock();
@@ -36,15 +37,22 @@ abstract class SuperIsolate {
_receivePort = ReceivePort();
// Get the root token before spawning the isolate
final rootToken = RootIsolateToken.instance;
if (rootToken == null && !isDartUiIsolate) {
logger.severe('Failed to get RootIsolateToken');
return;
}
try {
_isolate = isDartUiIsolate
? await DartUiIsolate.spawn(
_isolateMain,
_receivePort.sendPort,
[_receivePort.sendPort, null],
)
: await Isolate.spawn(
_isolateMain,
_receivePort.sendPort,
[_receivePort.sendPort, rootToken],
debugName: isolateName,
);
_mainSendPort = await _receivePort.first as SendPort;
@@ -60,12 +68,18 @@ abstract class SuperIsolate {
}
@pragma('vm:entry-point')
static void _isolateMain(SendPort mainSendPort) async {
static void _isolateMain(List<dynamic> args) async {
final SendPort mainSendPort = args[0] as SendPort;
final RootIsolateToken? rootToken = args[1] as RootIsolateToken?;
Logger.root.level = kDebugMode ? Level.ALL : Level.INFO;
final IsolateLogger isolateLogger = IsolateLogger();
Logger.root.onRecord.listen(isolateLogger.onLogRecordInIsolate);
final receivePort = ReceivePort();
mainSendPort.send(receivePort.sendPort);
if (rootToken != null) {
BackgroundIsolateBinaryMessenger.ensureInitialized(rootToken);
}
receivePort.listen((message) async {
final taskID = message[0] as String;
@@ -137,6 +151,24 @@ abstract class SuperIsolate {
bool postFunctionlockStop(IsolateOperation operation) => false;
/// Stores [value] under [key] in the isolate-side cache.
Future<void> cacheData(String key, dynamic value) async {
  final payload = <String, dynamic>{'key': key, 'value': value};
  await runInIsolate(IsolateOperation.setIsolateCache, payload);
}
/// Removes the value stored under [key] from the isolate's cache.
///
/// No-op if [key] is not present in the cache.
Future<void> clearCachedData(String key) async {
await runInIsolate(IsolateOperation.clearIsolateCache, {
'key': key,
});
}
/// Empties the isolate-side cache entirely.
///
/// Also invoked from [_disposeIsolate] so a respawned isolate starts clean.
Future<void> clearAllCachedData() async {
await runInIsolate(IsolateOperation.clearAllIsolateCache, {});
}
/// Resets a timer that kills the isolate after a certain amount of inactivity.
///
/// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
@@ -161,6 +193,7 @@ abstract class SuperIsolate {
void _disposeIsolate() async {
if (!_isIsolateSpawned) return;
logger.info('Disposing isolate');
await clearAllCachedData();
await onDispose();
_isIsolateSpawned = false;
_isolate.kill();

View File

@@ -0,0 +1,47 @@
import 'dart:async';
import 'dart:typed_data' show Uint8List;
import "package:logging/logging.dart";
import "package:photos/models/ml/face/box.dart";
import "package:photos/services/isolate_functions.dart";
import "package:photos/services/isolate_service.dart";
import "package:photos/utils/image_ml_util.dart";
/// Dedicated dart:ui isolate for rendering face-crop thumbnails.
class FaceThumbnailGenerator extends SuperIsolate {
@override
Logger get logger => _logger;
final _logger = Logger('FaceThumbnailGenerator');
// Thumbnail rendering uses canvas APIs, which require a dart:ui isolate.
@override
bool get isDartUiIsolate => true;
@override
String get isolateName => "FaceThumbnailGenerator";
// Isolate is killed automatically after a period of inactivity.
@override
bool get shouldAutomaticDispose => true;
// Singleton pattern
FaceThumbnailGenerator._privateConstructor();
static final FaceThumbnailGenerator instance =
FaceThumbnailGenerator._privateConstructor();
factory FaceThumbnailGenerator() => instance;
/// Generates face thumbnails for all [faceBoxes] in the image at [imagePath].
///
/// Uses [generateFaceThumbnailsUsingCanvas] inside the isolate.
Future<List<Uint8List>> generateFaceThumbnails(
String imagePath,
List<FaceBox> faceBoxes,
) async {
// Boxes are serialised to JSON maps before crossing the isolate boundary.
final List<Map<String, dynamic>> faceBoxesJson =
faceBoxes.map((box) => box.toJson()).toList();
return await runInIsolate(
IsolateOperation.generateFaceThumbnails,
{
'imagePath': imagePath,
'faceBoxesList': faceBoxesJson,
},
).then((value) => value.cast<Uint8List>());
}
}

View File

@@ -1,15 +1,13 @@
import 'dart:async';
import 'dart:typed_data' show Uint8List;
import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/models/ml/face/box.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/services/isolate_functions.dart";
import "package:photos/services/isolate_service.dart";
import "package:photos/services/machine_learning/ml_constants.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:synchronized/synchronized.dart";
class MLComputer extends SuperIsolate {
@@ -20,7 +18,7 @@ class MLComputer extends SuperIsolate {
final _initModelLock = Lock();
@override
bool get isDartUiIsolate => true;
bool get isDartUiIsolate => false;
@override
String get isolateName => "MLComputerIsolate";
@@ -33,24 +31,6 @@ class MLComputer extends SuperIsolate {
static final MLComputer instance = MLComputer._privateConstructor();
factory MLComputer() => instance;
/// Generates face thumbnails for all [faceBoxes] in the image at [imagePath].
///
/// Uses [generateFaceThumbnailsUsingCanvas] inside the isolate.
Future<List<Uint8List>> generateFaceThumbnails(
String imagePath,
List<FaceBox> faceBoxes,
) async {
// Boxes are serialised to JSON maps before crossing the isolate boundary.
final List<Map<String, dynamic>> faceBoxesJson =
faceBoxes.map((box) => box.toJson()).toList();
return await runInIsolate(
IsolateOperation.generateFaceThumbnails,
{
'imagePath': imagePath,
'faceBoxesList': faceBoxesJson,
},
).then((value) => value.cast<Uint8List>());
}
Future<List<double>> runClipText(String query) async {
try {
await _ensureLoadedClipTextModel();
@@ -66,19 +46,6 @@ class MLComputer extends SuperIsolate {
}
}
/// Computes, inside the isolate, the dot-product similarity of every vector
/// in [embeddings] against [otherEmbedding].
///
/// Returns a map of fileID -> similarity. The embeddings are serialised to
/// JSON strings for transfer to the isolate.
Future<Map<int, double>> compareEmbeddings(List<EmbeddingVector> embeddings, Vector otherEmbedding) async {
try {
final fileIdToSimilarity = await runInIsolate(IsolateOperation.compareEmbeddings, {
"embeddings": embeddings.map((e) => e.toJsonString()).toList(),
"otherEmbedding": otherEmbedding.toList(),
}) as Map<int, double>;
return fileIdToSimilarity;
} catch (e, s) {
_logger.severe("Could not compare embeddings MLComputer isolate", e, s);
rethrow;
}
}
Future<void> _ensureLoadedClipTextModel() async {
return _initModelLock.synchronized(() async {
if (ClipTextEncoder.instance.isInitialized) return;
@@ -114,4 +81,61 @@ class MLComputer extends SuperIsolate {
}
});
}
/// Runs, inside the MLComputer isolate, a similarity search of every text
/// query in [textQueryToEmbeddingMap] against the image embeddings that
/// were previously cached via [cacheImageEmbeddings].
///
/// For each query, results below that query's threshold in
/// [minimumSimilarityMap] are dropped, and the rest are returned sorted by
/// descending score. The image embeddings must already be cached in the
/// isolate before this is called.
Future<Map<String, List<QueryResult>>> computeBulkSimilarities(
Map<String, List<double>> textQueryToEmbeddingMap,
Map<String, double> minimumSimilarityMap,
) async {
try {
final queryToResults =
await runInIsolate(IsolateOperation.computeBulkSimilarities, {
"textQueryToEmbeddingMap": textQueryToEmbeddingMap,
"minimumSimilarityMap": minimumSimilarityMap,
}) as Map<String, List<QueryResult>>;
return queryToResults;
} catch (e, s) {
_logger.severe(
"Could not bulk compare embeddings inside MLComputer isolate",
e,
s,
);
rethrow;
}
}
/// Stores [embeddings] in the MLComputer isolate's in-memory cache under
/// [imageEmbeddingsKey], so later similarity queries can run without
/// re-sending the vectors across the isolate boundary.
Future<void> cacheImageEmbeddings(List<EmbeddingVector> embeddings) async {
  try {
    final cacheArgs = <String, dynamic>{
      'key': imageEmbeddingsKey,
      'value': embeddings,
    };
    await runInIsolate(IsolateOperation.setIsolateCache, cacheArgs) as bool;
    _logger.info(
      'Cached ${embeddings.length} image embeddings inside MLComputer',
    );
  } catch (e, s) {
    _logger.severe("Could not cache image embeddings in MLComputer", e, s);
    rethrow;
  }
}
/// Removes the cached image embeddings (stored under [imageEmbeddingsKey])
/// from the MLComputer isolate.
Future<void> clearImageEmbeddingsCache() async {
try {
await runInIsolate(
IsolateOperation.clearIsolateCache,
{'key': imageEmbeddingsKey},
) as bool;
return;
} catch (e, s) {
_logger.severe(
"Could not clear image embeddings cache in MLComputer",
e,
s,
);
rethrow;
}
}
}

View File

@@ -0,0 +1 @@
/// Isolate-cache key under which the CLIP image embeddings are stored.
const imageEmbeddingsKey = "imageEmbeddings";

View File

@@ -0,0 +1,6 @@
/// A single semantic-search hit: the matched file's [id] and its
/// similarity [score] against the text query.
class QueryResult {
  /// File ID of the matching image.
  final int id;

  /// Similarity score of the match; higher is more similar.
  final double score;

  // Const constructor: the class is an immutable value object.
  const QueryResult(this.id, this.score);
}

View File

@@ -1,12 +1,10 @@
import "dart:async" show unawaited;
import "dart:async" show Timer, unawaited;
import "dart:developer" as dev show log;
import "dart:math" show min;
import "dart:ui" show Image;
import "package:computer/computer.dart";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
@@ -15,13 +13,14 @@ import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/models/file/file.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
import "package:shared_preferences/shared_preferences.dart";
import "package:synchronized/synchronized.dart";
class SemanticSearchService {
static final _logger = Logger("SemanticSearchService");
@@ -30,7 +29,6 @@ class SemanticSearchService {
static final SemanticSearchService instance =
SemanticSearchService._privateConstructor();
static final Computer _computer = Computer.shared();
final LRUMap<String, List<double>> _queryEmbeddingCache = LRUMap(20);
static const kMinimumSimilarityThreshold = 0.175;
late final mlDataDB = MLDataDB.instance;
@@ -38,7 +36,11 @@ class SemanticSearchService {
bool _hasInitialized = false;
bool _textModelIsLoaded = false;
Future<List<EmbeddingVector>>? _cachedImageEmbeddingVectors;
final _cacheLock = Lock();
bool _imageEmbeddingsAreCached = false;
Timer? _embeddingsCacheTimer;
final Duration _embeddingsCacheDuration = const Duration(seconds: 60);
Future<(String, List<EnteFile>)>? _searchScreenRequest;
String? _latestPendingQuery;
@@ -53,12 +55,11 @@ class SemanticSearchService {
_logger.info("init called");
_hasInitialized = true;
// call getClipEmbeddings after 5 seconds
Future.delayed(const Duration(seconds: 5), () async {
await getClipVectors();
});
Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
_cachedImageEmbeddingVectors = null;
if (_imageEmbeddingsAreCached) {
MLComputer.instance.clearImageEmbeddingsCache();
_imageEmbeddingsAreCached = false;
}
});
unawaited(_loadTextModel(delay: true));
@@ -108,14 +109,21 @@ class SemanticSearchService {
_logger.info("Indexes cleared");
}
Future<List<EmbeddingVector>> getClipVectors() async {
if (_cachedImageEmbeddingVectors != null) {
return _cachedImageEmbeddingVectors!;
}
_cachedImageEmbeddingVectors ??= mlDataDB.getAllClipVectors();
_logger.info("read all embeddings from DB");
return _cachedImageEmbeddingVectors!;
/// Loads all CLIP image embeddings from the DB and caches them inside the
/// MLComputer isolate, if not already cached.
///
/// Serialised with [_cacheLock] so concurrent searches trigger only one DB
/// read and cache upload. Also (re)starts the cache-expiry timer.
Future<void> _cacheClipVectors() async {
return _cacheLock.synchronized(() async {
// Keep the cache alive while searches are actively using it.
_resetInactivityTimer();
if (_imageEmbeddingsAreCached) {
return;
}
final now = DateTime.now();
final imageEmbeddings = await mlDataDB.getAllClipVectors();
_logger.info(
"read all ${imageEmbeddings.length} embeddings from DB in ${DateTime.now().difference(now).inMilliseconds} ms",
);
await MLComputer.instance.cacheImageEmbeddings(imageEmbeddings);
_imageEmbeddingsAreCached = true;
return;
});
}
Future<List<EnteFile>> getMatchingFiles(
@@ -257,17 +265,10 @@ class SemanticSearchService {
required Map<String, double> minimumSimilarityMap,
}) async {
final startTime = DateTime.now();
final imageEmbeddings = await getClipVectors();
final Map<String, List<QueryResult>> queryResults = await _computer
.compute<Map<String, dynamic>, Map<String, List<QueryResult>>>(
computeBulkSimilarities,
param: {
"imageEmbeddings": imageEmbeddings,
"textQueryToEmbeddingMap": textQueryToEmbeddingMap,
"minimumSimilarityMap": minimumSimilarityMap,
},
taskName: "computeBulkSimilarities",
);
await _cacheClipVectors();
final Map<String, List<QueryResult>> queryResults = await MLComputer
.instance
.computeBulkSimilarities(textQueryToEmbeddingMap, minimumSimilarityMap);
final endTime = DateTime.now();
_logger.info(
"computingSimilarities took for ${textQueryToEmbeddingMap.length} queries " +
@@ -278,6 +279,19 @@ class SemanticSearchService {
return queryResults;
}
/// (Re)starts the countdown that evicts the isolate-side embeddings cache
/// after [_embeddingsCacheDuration] without any search activity.
void _resetInactivityTimer() {
_embeddingsCacheTimer?.cancel();
_embeddingsCacheTimer = Timer(_embeddingsCacheDuration, () {
_logger.info(
'Embeddings cache is unused for ${_embeddingsCacheDuration.inSeconds} seconds. Removing cache.',
);
if (_imageEmbeddingsAreCached) {
MLComputer.instance.clearImageEmbeddingsCache();
_imageEmbeddingsAreCached = false;
}
});
}
static Future<ClipResult> runClipImage(
int enteFileID,
Image image,
@@ -296,50 +310,3 @@ class SemanticSearchService {
return clipResult;
}
}
/// Scores every image embedding against every text-query embedding.
///
/// [args] carries "imageEmbeddings" (the image vectors), a
/// "textQueryToEmbeddingMap" of query -> text embedding, and a
/// "minimumSimilarityMap" of query -> score threshold. Returns, per query,
/// the above-threshold [QueryResult]s sorted by descending score.
///
/// The release and debug paths previously duplicated the inner similarity
/// loop; they are unified here, with the best-score bookkeeping guarded by
/// [kDebugMode] so release behaviour is unchanged.
Map<String, List<QueryResult>> computeBulkSimilarities(Map args) {
  final imageEmbeddings = args["imageEmbeddings"] as List<EmbeddingVector>;
  final textEmbedding =
      args["textQueryToEmbeddingMap"] as Map<String, List<double>>;
  final minimumSimilarityMap =
      args["minimumSimilarityMap"] as Map<String, double>;
  final result = <String, List<QueryResult>>{};
  for (final MapEntry<String, List<double>> entry in textEmbedding.entries) {
    final query = entry.key;
    final textVector = Vector.fromList(entry.value);
    final minimumSimilarity = minimumSimilarityMap[query]!;
    final queryResults = <QueryResult>[];
    // Tracked only in debug mode, to explain empty result sets.
    double bestScore = 0.0;
    for (final imageEmbedding in imageEmbeddings) {
      final similarity = imageEmbedding.vector.dot(textVector);
      if (similarity >= minimumSimilarity) {
        queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
      }
      if (kDebugMode && similarity > bestScore) {
        bestScore = similarity;
      }
    }
    if (kDebugMode && queryResults.isEmpty) {
      dev.log("No results found for query with best score: $bestScore");
    }
    // Highest-scoring matches first.
    queryResults.sort((first, second) => second.score.compareTo(first.score));
    result[query] = queryResults;
  }
  return result;
}
/// A single semantic-search hit: the matched file's [id] and its
/// similarity [score] against the text query.
class QueryResult {
/// File ID of the matching image.
final int id;
/// Similarity score of the match; higher is more similar.
final double score;
QueryResult(this.id, this.score);
}

View File

@@ -35,7 +35,6 @@ import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/services/search_service.dart";
class MemoriesResult {
@@ -92,8 +91,7 @@ class SmartMemoriesService {
await MLDataDB.instance.getFileIDsToFacesWithoutEmbedding();
_logger.finest('fileIdToFaces has ${fileIdToFaces.length} entries $t');
final allImageEmbeddings =
await SemanticSearchService.instance.getClipVectors();
final allImageEmbeddings = await MLDataDB.instance.getAllClipVectors();
_logger.finest(
'allImageEmbeddings has ${allImageEmbeddings.length} entries $t',
);

View File

@@ -10,7 +10,7 @@ import "package:photos/models/file/file.dart";
import "package:photos/models/file/file_type.dart";
import "package:photos/models/ml/face/box.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
import "package:photos/services/machine_learning/face_thumbnail_generator.dart";
import "package:photos/utils/file_util.dart";
import "package:photos/utils/thumbnail_util.dart";
import "package:pool/pool.dart";
@@ -210,7 +210,7 @@ Future<Map<String, Uint8List>?> _getFaceCrops(
faceBoxes.add(e.value);
}
final List<Uint8List> faceCrop =
await MLComputer.instance.generateFaceThumbnails(
await FaceThumbnailGenerator.instance.generateFaceThumbnails(
// await generateJpgFaceThumbnails(
imagePath,
faceBoxes,