[mob][photos] Image embedding caching perf (#5459)
## Description Made changes to more efficiently handle image embeddings caching, keeping them cached only in MLComputer and only for limited times when the user needs it. ## Tests Tested in debug mode on my pixel phone.
This commit is contained in:
@@ -1166,8 +1166,21 @@ class MLDataDB with SqlDbBase implements IMLDataDB<int> {
|
||||
Future<List<EmbeddingVector>> getAllClipVectors() async {
|
||||
Logger("ClipDB").info("reading all embeddings from DB");
|
||||
final db = await instance.asyncDB;
|
||||
final results = await db.getAll('SELECT * FROM $clipTable');
|
||||
return _convertToVectors(results);
|
||||
final results = await db
|
||||
.getAll('SELECT $fileIDColumn, $embeddingColumn FROM $clipTable');
|
||||
|
||||
// Convert rows to vectors
|
||||
final List<EmbeddingVector> embeddings = [];
|
||||
for (final result in results) {
|
||||
// Convert to EmbeddingVector
|
||||
final embedding = EmbeddingVector(
|
||||
fileID: result[fileIDColumn],
|
||||
embedding: Float32List.view(result[embeddingColumn].buffer),
|
||||
);
|
||||
if (embedding.isEmpty) continue;
|
||||
embeddings.add(embedding);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
// Get indexed FileIDs
|
||||
@@ -1229,23 +1242,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB<int> {
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
|
||||
final List<EmbeddingVector> embeddings = [];
|
||||
for (final result in results) {
|
||||
final embedding = _getVectorFromRow(result);
|
||||
if (embedding.isEmpty) continue;
|
||||
embeddings.add(embedding);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
|
||||
final fileID = row[fileIDColumn] as int;
|
||||
final bytes = row[embeddingColumn] as Uint8List;
|
||||
final list = Float32List.view(bytes.buffer);
|
||||
return EmbeddingVector(fileID: fileID, embedding: list);
|
||||
}
|
||||
|
||||
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
|
||||
return [
|
||||
embedding.fileID,
|
||||
|
||||
@@ -5,13 +5,17 @@ import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/models/ml/face/box.dart";
|
||||
import "package:photos/models/ml/vector.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_constants.dart";
|
||||
import "package:photos/services/machine_learning/ml_model.dart";
|
||||
import "package:photos/services/machine_learning/ml_result.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
final Map<String, dynamic> _isolateCache = {};
|
||||
|
||||
enum IsolateOperation {
|
||||
/// [MLIndexingIsolate]
|
||||
analyzeImage,
|
||||
@@ -35,10 +39,15 @@ enum IsolateOperation {
|
||||
runClipText,
|
||||
|
||||
/// [MLComputer]
|
||||
compareEmbeddings,
|
||||
computeBulkSimilarities,
|
||||
|
||||
/// [FaceClusteringService]
|
||||
linearIncrementalClustering,
|
||||
|
||||
/// Cache operations
|
||||
setIsolateCache,
|
||||
clearIsolateCache,
|
||||
clearAllIsolateCache,
|
||||
}
|
||||
|
||||
/// WARNING: Only return primitives unless you know the method is only going
|
||||
@@ -121,19 +130,31 @@ Future<dynamic> isolateFunction(
|
||||
return List<double>.from(textEmbedding, growable: false);
|
||||
|
||||
/// MLComputer
|
||||
case IsolateOperation.compareEmbeddings:
|
||||
final List<EmbeddingVector> embeddings =
|
||||
(args['embeddings'] as List<String>)
|
||||
.map((jsonString) => EmbeddingVector.fromJsonString(jsonString))
|
||||
.toList();
|
||||
final otherEmbedding =
|
||||
Vector.fromList(args['otherEmbedding'] as List<double>);
|
||||
final Map<int, double> result = {};
|
||||
for (final embedding in embeddings) {
|
||||
final double similarity = embedding.vector.dot(otherEmbedding);
|
||||
result[embedding.fileID] = similarity;
|
||||
case IsolateOperation.computeBulkSimilarities:
|
||||
final imageEmbeddings =
|
||||
_isolateCache[imageEmbeddingsKey] as List<EmbeddingVector>;
|
||||
final textEmbedding =
|
||||
args["textQueryToEmbeddingMap"] as Map<String, List<double>>;
|
||||
final minimumSimilarityMap =
|
||||
args["minimumSimilarityMap"] as Map<String, double>;
|
||||
final result = <String, List<QueryResult>>{};
|
||||
for (final MapEntry<String, List<double>> entry
|
||||
in textEmbedding.entries) {
|
||||
final query = entry.key;
|
||||
final textVector = Vector.fromList(entry.value);
|
||||
final minimumSimilarity = minimumSimilarityMap[query]!;
|
||||
final queryResults = <QueryResult>[];
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final similarity = imageEmbedding.vector.dot(textVector);
|
||||
if (similarity >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
|
||||
}
|
||||
}
|
||||
queryResults
|
||||
.sort((first, second) => second.score.compareTo(first.score));
|
||||
result[query] = queryResults;
|
||||
}
|
||||
return Map<int, double>.from(result);
|
||||
return result;
|
||||
|
||||
/// Cases for MLComputer end here
|
||||
|
||||
@@ -145,5 +166,27 @@ Future<dynamic> isolateFunction(
|
||||
return result;
|
||||
|
||||
/// Cases for FaceClusteringService end here
|
||||
|
||||
/// Cases for Caching start here
|
||||
|
||||
/// Caching
|
||||
case IsolateOperation.setIsolateCache:
|
||||
final key = args['key'] as String;
|
||||
final value = args['value'];
|
||||
_isolateCache[key] = value;
|
||||
return true;
|
||||
|
||||
/// Caching
|
||||
case IsolateOperation.clearIsolateCache:
|
||||
final key = args['key'] as String;
|
||||
_isolateCache.remove(key);
|
||||
return true;
|
||||
|
||||
/// Caching
|
||||
case IsolateOperation.clearAllIsolateCache:
|
||||
_isolateCache.clear();
|
||||
return true;
|
||||
|
||||
/// Cases for Caching stop here
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ import 'dart:isolate';
|
||||
|
||||
import "package:dart_ui_isolate/dart_ui_isolate.dart";
|
||||
import "package:flutter/foundation.dart" show kDebugMode;
|
||||
import "package:flutter/services.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/core/error-reporting/isolate_logging.dart";
|
||||
import "package:photos/models/base/id.dart";
|
||||
@@ -13,7 +14,7 @@ abstract class SuperIsolate {
|
||||
Logger get logger;
|
||||
|
||||
Timer? _inactivityTimer;
|
||||
final Duration _inactivityDuration = const Duration(seconds: 120);
|
||||
final Duration _inactivityDuration = const Duration(seconds: 60);
|
||||
int _activeTasks = 0;
|
||||
|
||||
final _initIsolateLock = Lock();
|
||||
@@ -36,15 +37,22 @@ abstract class SuperIsolate {
|
||||
|
||||
_receivePort = ReceivePort();
|
||||
|
||||
// Get the root token before spawning the isolate
|
||||
final rootToken = RootIsolateToken.instance;
|
||||
if (rootToken == null && !isDartUiIsolate) {
|
||||
logger.severe('Failed to get RootIsolateToken');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
_isolate = isDartUiIsolate
|
||||
? await DartUiIsolate.spawn(
|
||||
_isolateMain,
|
||||
_receivePort.sendPort,
|
||||
[_receivePort.sendPort, null],
|
||||
)
|
||||
: await Isolate.spawn(
|
||||
_isolateMain,
|
||||
_receivePort.sendPort,
|
||||
[_receivePort.sendPort, rootToken],
|
||||
debugName: isolateName,
|
||||
);
|
||||
_mainSendPort = await _receivePort.first as SendPort;
|
||||
@@ -60,12 +68,18 @@ abstract class SuperIsolate {
|
||||
}
|
||||
|
||||
@pragma('vm:entry-point')
|
||||
static void _isolateMain(SendPort mainSendPort) async {
|
||||
static void _isolateMain(List<dynamic> args) async {
|
||||
final SendPort mainSendPort = args[0] as SendPort;
|
||||
final RootIsolateToken? rootToken = args[1] as RootIsolateToken?;
|
||||
|
||||
Logger.root.level = kDebugMode ? Level.ALL : Level.INFO;
|
||||
final IsolateLogger isolateLogger = IsolateLogger();
|
||||
Logger.root.onRecord.listen(isolateLogger.onLogRecordInIsolate);
|
||||
final receivePort = ReceivePort();
|
||||
mainSendPort.send(receivePort.sendPort);
|
||||
if (rootToken != null) {
|
||||
BackgroundIsolateBinaryMessenger.ensureInitialized(rootToken);
|
||||
}
|
||||
|
||||
receivePort.listen((message) async {
|
||||
final taskID = message[0] as String;
|
||||
@@ -137,6 +151,24 @@ abstract class SuperIsolate {
|
||||
|
||||
bool postFunctionlockStop(IsolateOperation operation) => false;
|
||||
|
||||
Future<void> cacheData(String key, dynamic value) async {
|
||||
await runInIsolate(IsolateOperation.setIsolateCache, {
|
||||
'key': key,
|
||||
'value': value,
|
||||
});
|
||||
}
|
||||
|
||||
/// Clears specific data from the isolate's cache
|
||||
Future<void> clearCachedData(String key) async {
|
||||
await runInIsolate(IsolateOperation.clearIsolateCache, {
|
||||
'key': key,
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> clearAllCachedData() async {
|
||||
await runInIsolate(IsolateOperation.clearAllIsolateCache, {});
|
||||
}
|
||||
|
||||
/// Resets a timer that kills the isolate after a certain amount of inactivity.
|
||||
///
|
||||
/// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
|
||||
@@ -161,6 +193,7 @@ abstract class SuperIsolate {
|
||||
void _disposeIsolate() async {
|
||||
if (!_isIsolateSpawned) return;
|
||||
logger.info('Disposing isolate');
|
||||
await clearAllCachedData();
|
||||
await onDispose();
|
||||
_isIsolateSpawned = false;
|
||||
_isolate.kill();
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import 'dart:async';
|
||||
import 'dart:typed_data' show Uint8List;
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/models/ml/face/box.dart";
|
||||
import "package:photos/services/isolate_functions.dart";
|
||||
import "package:photos/services/isolate_service.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
|
||||
class FaceThumbnailGenerator extends SuperIsolate {
|
||||
@override
|
||||
Logger get logger => _logger;
|
||||
final _logger = Logger('FaceThumbnailGenerator');
|
||||
|
||||
@override
|
||||
bool get isDartUiIsolate => true;
|
||||
|
||||
@override
|
||||
String get isolateName => "FaceThumbnailGenerator";
|
||||
|
||||
@override
|
||||
bool get shouldAutomaticDispose => true;
|
||||
|
||||
// Singleton pattern
|
||||
FaceThumbnailGenerator._privateConstructor();
|
||||
static final FaceThumbnailGenerator instance =
|
||||
FaceThumbnailGenerator._privateConstructor();
|
||||
factory FaceThumbnailGenerator() => instance;
|
||||
|
||||
/// Generates face thumbnails for all [faceBoxes] in [imageData].
|
||||
///
|
||||
/// Uses [generateFaceThumbnailsUsingCanvas] inside the isolate.
|
||||
Future<List<Uint8List>> generateFaceThumbnails(
|
||||
String imagePath,
|
||||
List<FaceBox> faceBoxes,
|
||||
) async {
|
||||
final List<Map<String, dynamic>> faceBoxesJson =
|
||||
faceBoxes.map((box) => box.toJson()).toList();
|
||||
return await runInIsolate(
|
||||
IsolateOperation.generateFaceThumbnails,
|
||||
{
|
||||
'imagePath': imagePath,
|
||||
'faceBoxesList': faceBoxesJson,
|
||||
},
|
||||
).then((value) => value.cast<Uint8List>());
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,13 @@
|
||||
import 'dart:async';
|
||||
import 'dart:typed_data' show Uint8List;
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/models/ml/face/box.dart";
|
||||
import "package:photos/models/ml/vector.dart";
|
||||
import "package:photos/services/isolate_functions.dart";
|
||||
import "package:photos/services/isolate_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_constants.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
class MLComputer extends SuperIsolate {
|
||||
@@ -20,7 +18,7 @@ class MLComputer extends SuperIsolate {
|
||||
final _initModelLock = Lock();
|
||||
|
||||
@override
|
||||
bool get isDartUiIsolate => true;
|
||||
bool get isDartUiIsolate => false;
|
||||
|
||||
@override
|
||||
String get isolateName => "MLComputerIsolate";
|
||||
@@ -33,24 +31,6 @@ class MLComputer extends SuperIsolate {
|
||||
static final MLComputer instance = MLComputer._privateConstructor();
|
||||
factory MLComputer() => instance;
|
||||
|
||||
/// Generates face thumbnails for all [faceBoxes] in [imageData].
|
||||
///
|
||||
/// Uses [generateFaceThumbnailsUsingCanvas] inside the isolate.
|
||||
Future<List<Uint8List>> generateFaceThumbnails(
|
||||
String imagePath,
|
||||
List<FaceBox> faceBoxes,
|
||||
) async {
|
||||
final List<Map<String, dynamic>> faceBoxesJson =
|
||||
faceBoxes.map((box) => box.toJson()).toList();
|
||||
return await runInIsolate(
|
||||
IsolateOperation.generateFaceThumbnails,
|
||||
{
|
||||
'imagePath': imagePath,
|
||||
'faceBoxesList': faceBoxesJson,
|
||||
},
|
||||
).then((value) => value.cast<Uint8List>());
|
||||
}
|
||||
|
||||
Future<List<double>> runClipText(String query) async {
|
||||
try {
|
||||
await _ensureLoadedClipTextModel();
|
||||
@@ -66,19 +46,6 @@ class MLComputer extends SuperIsolate {
|
||||
}
|
||||
}
|
||||
|
||||
Future<Map<int, double>> compareEmbeddings(List<EmbeddingVector> embeddings, Vector otherEmbedding) async {
|
||||
try {
|
||||
final fileIdToSimilarity = await runInIsolate(IsolateOperation.compareEmbeddings, {
|
||||
"embeddings": embeddings.map((e) => e.toJsonString()).toList(),
|
||||
"otherEmbedding": otherEmbedding.toList(),
|
||||
}) as Map<int, double>;
|
||||
return fileIdToSimilarity;
|
||||
} catch (e, s) {
|
||||
_logger.severe("Could not compare embeddings MLComputer isolate", e, s);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _ensureLoadedClipTextModel() async {
|
||||
return _initModelLock.synchronized(() async {
|
||||
if (ClipTextEncoder.instance.isInitialized) return;
|
||||
@@ -114,4 +81,61 @@ class MLComputer extends SuperIsolate {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Future<Map<String, List<QueryResult>>> computeBulkSimilarities(
|
||||
Map<String, List<double>> textQueryToEmbeddingMap,
|
||||
Map<String, double> minimumSimilarityMap,
|
||||
) async {
|
||||
try {
|
||||
final queryToResults =
|
||||
await runInIsolate(IsolateOperation.computeBulkSimilarities, {
|
||||
"textQueryToEmbeddingMap": textQueryToEmbeddingMap,
|
||||
"minimumSimilarityMap": minimumSimilarityMap,
|
||||
}) as Map<String, List<QueryResult>>;
|
||||
return queryToResults;
|
||||
} catch (e, s) {
|
||||
_logger.severe(
|
||||
"Could not bulk compare embeddings inside MLComputer isolate",
|
||||
e,
|
||||
s,
|
||||
);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> cacheImageEmbeddings(List<EmbeddingVector> embeddings) async {
|
||||
try {
|
||||
await runInIsolate(
|
||||
IsolateOperation.setIsolateCache,
|
||||
{
|
||||
'key': imageEmbeddingsKey,
|
||||
'value': embeddings,
|
||||
},
|
||||
) as bool;
|
||||
_logger.info(
|
||||
'Cached ${embeddings.length} image embeddings inside MLComputer',
|
||||
);
|
||||
return;
|
||||
} catch (e, s) {
|
||||
_logger.severe("Could not cache image embeddings in MLComputer", e, s);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> clearImageEmbeddingsCache() async {
|
||||
try {
|
||||
await runInIsolate(
|
||||
IsolateOperation.clearIsolateCache,
|
||||
{'key': imageEmbeddingsKey},
|
||||
) as bool;
|
||||
return;
|
||||
} catch (e, s) {
|
||||
_logger.severe(
|
||||
"Could not clear image embeddings cache in MLComputer",
|
||||
e,
|
||||
s,
|
||||
);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1
mobile/lib/services/machine_learning/ml_constants.dart
Normal file
1
mobile/lib/services/machine_learning/ml_constants.dart
Normal file
@@ -0,0 +1 @@
|
||||
const imageEmbeddingsKey = "imageEmbeddings";
|
||||
@@ -0,0 +1,6 @@
|
||||
class QueryResult {
|
||||
final int id;
|
||||
final double score;
|
||||
|
||||
QueryResult(this.id, this.score);
|
||||
}
|
||||
@@ -1,12 +1,10 @@
|
||||
import "dart:async" show unawaited;
|
||||
import "dart:async" show Timer, unawaited;
|
||||
import "dart:developer" as dev show log;
|
||||
import "dart:math" show min;
|
||||
import "dart:ui" show Image;
|
||||
|
||||
import "package:computer/computer.dart";
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/vector.dart";
|
||||
import "package:photos/core/cache/lru_map.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
@@ -15,13 +13,14 @@ import 'package:photos/events/embedding_updated_event.dart';
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/ml/clip.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:photos/models/ml/vector.dart";
|
||||
import "package:photos/service_locator.dart";
|
||||
import "package:photos/services/collections_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_computer.dart";
|
||||
import "package:photos/services/machine_learning/ml_result.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/query_result.dart";
|
||||
import "package:shared_preferences/shared_preferences.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
class SemanticSearchService {
|
||||
static final _logger = Logger("SemanticSearchService");
|
||||
@@ -30,7 +29,6 @@ class SemanticSearchService {
|
||||
static final SemanticSearchService instance =
|
||||
SemanticSearchService._privateConstructor();
|
||||
|
||||
static final Computer _computer = Computer.shared();
|
||||
final LRUMap<String, List<double>> _queryEmbeddingCache = LRUMap(20);
|
||||
static const kMinimumSimilarityThreshold = 0.175;
|
||||
late final mlDataDB = MLDataDB.instance;
|
||||
@@ -38,7 +36,11 @@ class SemanticSearchService {
|
||||
bool _hasInitialized = false;
|
||||
bool _textModelIsLoaded = false;
|
||||
|
||||
Future<List<EmbeddingVector>>? _cachedImageEmbeddingVectors;
|
||||
final _cacheLock = Lock();
|
||||
bool _imageEmbeddingsAreCached = false;
|
||||
Timer? _embeddingsCacheTimer;
|
||||
final Duration _embeddingsCacheDuration = const Duration(seconds: 60);
|
||||
|
||||
Future<(String, List<EnteFile>)>? _searchScreenRequest;
|
||||
String? _latestPendingQuery;
|
||||
|
||||
@@ -53,12 +55,11 @@ class SemanticSearchService {
|
||||
_logger.info("init called");
|
||||
_hasInitialized = true;
|
||||
|
||||
// call getClipEmbeddings after 5 seconds
|
||||
Future.delayed(const Duration(seconds: 5), () async {
|
||||
await getClipVectors();
|
||||
});
|
||||
Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
|
||||
_cachedImageEmbeddingVectors = null;
|
||||
if (_imageEmbeddingsAreCached) {
|
||||
MLComputer.instance.clearImageEmbeddingsCache();
|
||||
_imageEmbeddingsAreCached = false;
|
||||
}
|
||||
});
|
||||
|
||||
unawaited(_loadTextModel(delay: true));
|
||||
@@ -108,14 +109,21 @@ class SemanticSearchService {
|
||||
_logger.info("Indexes cleared");
|
||||
}
|
||||
|
||||
Future<List<EmbeddingVector>> getClipVectors() async {
|
||||
if (_cachedImageEmbeddingVectors != null) {
|
||||
return _cachedImageEmbeddingVectors!;
|
||||
}
|
||||
_cachedImageEmbeddingVectors ??= mlDataDB.getAllClipVectors();
|
||||
_logger.info("read all embeddings from DB");
|
||||
|
||||
return _cachedImageEmbeddingVectors!;
|
||||
Future<void> _cacheClipVectors() async {
|
||||
return _cacheLock.synchronized(() async {
|
||||
_resetInactivityTimer();
|
||||
if (_imageEmbeddingsAreCached) {
|
||||
return;
|
||||
}
|
||||
final now = DateTime.now();
|
||||
final imageEmbeddings = await mlDataDB.getAllClipVectors();
|
||||
_logger.info(
|
||||
"read all ${imageEmbeddings.length} embeddings from DB in ${DateTime.now().difference(now).inMilliseconds} ms",
|
||||
);
|
||||
await MLComputer.instance.cacheImageEmbeddings(imageEmbeddings);
|
||||
_imageEmbeddingsAreCached = true;
|
||||
return;
|
||||
});
|
||||
}
|
||||
|
||||
Future<List<EnteFile>> getMatchingFiles(
|
||||
@@ -257,17 +265,10 @@ class SemanticSearchService {
|
||||
required Map<String, double> minimumSimilarityMap,
|
||||
}) async {
|
||||
final startTime = DateTime.now();
|
||||
final imageEmbeddings = await getClipVectors();
|
||||
final Map<String, List<QueryResult>> queryResults = await _computer
|
||||
.compute<Map<String, dynamic>, Map<String, List<QueryResult>>>(
|
||||
computeBulkSimilarities,
|
||||
param: {
|
||||
"imageEmbeddings": imageEmbeddings,
|
||||
"textQueryToEmbeddingMap": textQueryToEmbeddingMap,
|
||||
"minimumSimilarityMap": minimumSimilarityMap,
|
||||
},
|
||||
taskName: "computeBulkSimilarities",
|
||||
);
|
||||
await _cacheClipVectors();
|
||||
final Map<String, List<QueryResult>> queryResults = await MLComputer
|
||||
.instance
|
||||
.computeBulkSimilarities(textQueryToEmbeddingMap, minimumSimilarityMap);
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"computingSimilarities took for ${textQueryToEmbeddingMap.length} queries " +
|
||||
@@ -278,6 +279,19 @@ class SemanticSearchService {
|
||||
return queryResults;
|
||||
}
|
||||
|
||||
void _resetInactivityTimer() {
|
||||
_embeddingsCacheTimer?.cancel();
|
||||
_embeddingsCacheTimer = Timer(_embeddingsCacheDuration, () {
|
||||
_logger.info(
|
||||
'Embeddings cache is unused for ${_embeddingsCacheDuration.inSeconds} seconds. Removing cache.',
|
||||
);
|
||||
if (_imageEmbeddingsAreCached) {
|
||||
MLComputer.instance.clearImageEmbeddingsCache();
|
||||
_imageEmbeddingsAreCached = false;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static Future<ClipResult> runClipImage(
|
||||
int enteFileID,
|
||||
Image image,
|
||||
@@ -296,50 +310,3 @@ class SemanticSearchService {
|
||||
return clipResult;
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, List<QueryResult>> computeBulkSimilarities(Map args) {
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<EmbeddingVector>;
|
||||
final textEmbedding =
|
||||
args["textQueryToEmbeddingMap"] as Map<String, List<double>>;
|
||||
final minimumSimilarityMap =
|
||||
args["minimumSimilarityMap"] as Map<String, double>;
|
||||
final result = <String, List<QueryResult>>{};
|
||||
for (final MapEntry<String, List<double>> entry in textEmbedding.entries) {
|
||||
final query = entry.key;
|
||||
final textVector = Vector.fromList(entry.value);
|
||||
final minimumSimilarity = minimumSimilarityMap[query]!;
|
||||
final queryResults = <QueryResult>[];
|
||||
if (!kDebugMode) {
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final similarity = imageEmbedding.vector.dot(textVector);
|
||||
if (similarity >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
double bestScore = 0.0;
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final similarity = imageEmbedding.vector.dot(textVector);
|
||||
if (similarity >= minimumSimilarity) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, similarity));
|
||||
}
|
||||
if (similarity > bestScore) {
|
||||
bestScore = similarity;
|
||||
}
|
||||
}
|
||||
if (kDebugMode && queryResults.isEmpty) {
|
||||
dev.log("No results found for query with best score: $bestScore");
|
||||
}
|
||||
}
|
||||
queryResults.sort((first, second) => second.score.compareTo(first.score));
|
||||
result[query] = queryResults;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
class QueryResult {
|
||||
final int id;
|
||||
final double score;
|
||||
|
||||
QueryResult(this.id, this.score);
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ import "package:photos/services/location_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_computer.dart";
|
||||
import "package:photos/services/machine_learning/ml_result.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
||||
import "package:photos/services/search_service.dart";
|
||||
|
||||
class MemoriesResult {
|
||||
@@ -92,8 +91,7 @@ class SmartMemoriesService {
|
||||
await MLDataDB.instance.getFileIDsToFacesWithoutEmbedding();
|
||||
_logger.finest('fileIdToFaces has ${fileIdToFaces.length} entries $t');
|
||||
|
||||
final allImageEmbeddings =
|
||||
await SemanticSearchService.instance.getClipVectors();
|
||||
final allImageEmbeddings = await MLDataDB.instance.getAllClipVectors();
|
||||
_logger.finest(
|
||||
'allImageEmbeddings has ${allImageEmbeddings.length} entries $t',
|
||||
);
|
||||
|
||||
@@ -10,7 +10,7 @@ import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/file/file_type.dart";
|
||||
import "package:photos/models/ml/face/box.dart";
|
||||
import "package:photos/models/ml/face/face.dart";
|
||||
import "package:photos/services/machine_learning/ml_computer.dart";
|
||||
import "package:photos/services/machine_learning/face_thumbnail_generator.dart";
|
||||
import "package:photos/utils/file_util.dart";
|
||||
import "package:photos/utils/thumbnail_util.dart";
|
||||
import "package:pool/pool.dart";
|
||||
@@ -210,7 +210,7 @@ Future<Map<String, Uint8List>?> _getFaceCrops(
|
||||
faceBoxes.add(e.value);
|
||||
}
|
||||
final List<Uint8List> faceCrop =
|
||||
await MLComputer.instance.generateFaceThumbnails(
|
||||
await FaceThumbnailGenerator.instance.generateFaceThumbnails(
|
||||
// await generateJpgFaceThumbnails(
|
||||
imagePath,
|
||||
faceBoxes,
|
||||
|
||||
Reference in New Issue
Block a user