[mob][photos] Add clip indexing to MLService

laurenspriem
2024-07-04 07:24:42 +05:30
parent 11656a59a6
commit e65b4643cd
3 changed files with 135 additions and 86 deletions

View File

@@ -9,7 +9,6 @@ import "package:computer/computer.dart";
import "package:dart_ui_isolate/dart_ui_isolate.dart";
import "package:flutter/foundation.dart" show debugPrint;
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:package_info_plus/package_info_plus.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
@@ -22,7 +21,6 @@ import "package:photos/face/model/detection.dart" as face_detection;
import "package:photos/face/model/dimension.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/landmark.dart";
import "package:photos/models/file/file.dart";
import "package:photos/service_locator.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
@@ -35,6 +33,8 @@ import 'package:photos/services/machine_learning/file_ml/file_ml.dart';
import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart';
import 'package:photos/services/machine_learning/ml_exceptions.dart';
import 'package:photos/services/machine_learning/ml_result.dart';
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
@@ -192,7 +192,7 @@ class FaceMlService {
break outerLoop;
}
await _ensureReadyForInference();
futures.add(processImage(instruction.enteFile));
futures.add(processImage(instruction));
}
final awaitedFutures = await Future.wait(futures);
final sumFutures = awaitedFutures.fold<int>(
@@ -385,22 +385,20 @@ class FaceMlService {
}
}
Future<bool> processImage(EnteFile enteFile) async {
Future<bool> processImage(FileMLInstruction instruction) async { // TODO: clean this function up
_logger.info(
"`processImage` start processing image with uploadedFileID: ${enteFile.uploadedFileID}",
"`processImage` start processing image with uploadedFileID: ${instruction.enteFile.uploadedFileID}",
);
bool actuallyRanML = false;
try {
final MLResult? result = await _analyzeImageInSingleIsolate(
enteFile,
// preferUsingThumbnailForEverything: false,
// disposeImageIsolateAfterUse: false,
instruction,
);
if (result == null) {
if (!_shouldPauseIndexingAndClustering) {
_logger.severe(
"Failed to analyze image with uploadedFileID: ${enteFile.uploadedFileID}",
"Failed to analyze image with uploadedFileID: ${instruction.enteFile.uploadedFileID}",
);
}
return actuallyRanML;
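
Note: `processImage` now receives the whole `FileMLInstruction` instead of just its `EnteFile`. Besides the file, the instruction carries the per-pipeline flags that later hunks read (`shouldRunFaces`, `shouldRunClip`). The class itself is defined outside this diff; a sketch of its inferred shape, hypothetical beyond the three members the diff actually uses:

```dart
// Inferred shape of FileMLInstruction, based only on the members this
// diff touches (enteFile, shouldRunFaces, shouldRunClip); the real
// definition lives elsewhere in the codebase and may differ.
class FileMLInstruction {
  final EnteFile enteFile;
  final bool shouldRunFaces; // face detection + embedding still pending
  final bool shouldRunClip; // CLIP image embedding still pending

  FileMLInstruction({
    required this.enteFile,
    required this.shouldRunFaces,
    required this.shouldRunClip,
  });
}
```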
@@ -410,7 +408,7 @@ class FaceMlService {
final List<Face> faces = [];
if (result.foundNoFaces) {
debugPrint(
'No faces detected for file with name:${enteFile.displayName}',
'No faces detected for file with name:${instruction.enteFile.displayName}',
);
faces.add(
Face.empty(result.fileId, error: result.errorOccured),
@@ -421,9 +419,9 @@ class FaceMlService {
result.decodedImageSize.height == -1) {
_logger.severe(
"decodedImageSize is not stored correctly for image with "
"ID: ${enteFile.uploadedFileID}");
"ID: ${instruction.enteFile.uploadedFileID}");
_logger.info(
"Using aligned image size for image with ID: ${enteFile.uploadedFileID}. This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata",
"Using aligned image size for image with ID: ${instruction.enteFile.uploadedFileID}. This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${instruction.enteFile.width}x${instruction.enteFile.height} in the metadata",
);
}
for (int i = 0; i < result.faces!.length; ++i) {
@@ -463,9 +461,9 @@ class FaceMlService {
_logger.info("inserting ${faces.length} faces for ${result.fileId}");
if (!result.errorOccured) {
await RemoteFileMLService.instance.putFileEmbedding(
enteFile,
instruction.enteFile,
FileMl(
enteFile.uploadedFileID!,
instruction.enteFile.uploadedFileID!,
FaceEmbeddings(
faces,
result.mlVersion,
@@ -483,18 +481,30 @@ class FaceMlService {
await FaceMLDataDB.instance.bulkInsertFaces(faces);
return actuallyRanML;
}
if (result.clipRan) {
actuallyRanML = true;
await SemanticSearchService.storeClipImageResult(
result.clip!,
instruction.enteFile,
);
}
} on ThumbnailRetrievalException catch (e, s) {
_logger.severe(
'ThumbnailRetrievalException while processing image with ID ${enteFile.uploadedFileID}, storing empty face so indexing does not get stuck',
'ThumbnailRetrievalException while processing image with ID ${instruction.enteFile.uploadedFileID}, storing empty face so indexing does not get stuck',
e,
s,
);
await FaceMLDataDB.instance
.bulkInsertFaces([Face.empty(enteFile.uploadedFileID!, error: true)]);
await FaceMLDataDB.instance.bulkInsertFaces(
[Face.empty(instruction.enteFile.uploadedFileID!, error: true)],
);
await SemanticSearchService.storeEmptyClipImageResult(
instruction.enteFile,
);
return true;
} catch (e, s) {
_logger.severe(
"Failed to analyze using FaceML for image with ID: ${enteFile.uploadedFileID}. Not storing any faces, which means it will be automatically retried later.",
"Failed to analyze using FaceML for image with ID: ${instruction.enteFile.uploadedFileID}. Not storing any faces, which means it will be automatically retried later.",
e,
s,
);
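
Note: the `ThumbnailRetrievalException` handler now writes an empty marker for both pipelines, a placeholder `Face` and an empty CLIP embedding, so the file counts as indexed and the queue stops retrying it. A condensed sketch of the pattern, with names from the diff:

```dart
// "Store empty markers so indexing doesn't get stuck": record a result
// for every pipeline so the file is no longer selected as pending.
Future<bool> markFileAsFailed(FileMLInstruction instruction) async {
  final fileID = instruction.enteFile.uploadedFileID!;
  await FaceMLDataDB.instance.bulkInsertFaces(
    [Face.empty(fileID, error: true)],
  );
  await SemanticSearchService.storeEmptyClipImageResult(instruction.enteFile);
  return true; // reported as "ran ML", matching the diff
}
```

Contrast with the generic `catch` just below, which deliberately stores nothing so the file is retried on a later pass.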
@@ -524,6 +534,11 @@ class FaceMlService {
} catch (e, s) {
_logger.severe("Could not initialize mobilefacenet", e, s);
}
try {
await ClipImageEncoder.instance.init();
} catch (e, s) {
_logger.severe("Could not initialize clip image", e, s);
}
_isModelsInitialized = true;
_logger.info('initModels done');
_logStatus();
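
Note: each model gets its own `try`/`catch` in `initModels`, so a CLIP download or load failure degrades semantic search without blocking face indexing, and vice versa. A generalized sketch of that pattern (the diff simply writes the blocks out longhand):

```dart
// Initialize each model independently; one failure must not abort the rest.
Future<void> initModelsIndependently(
  Map<String, Future<void> Function()> inits,
) async {
  for (final entry in inits.entries) {
    try {
      await entry.value();
    } catch (e, s) {
      Logger("MLService").severe("Could not initialize ${entry.key}", e, s);
    }
  }
}
```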
@@ -546,7 +561,11 @@ class FaceMlService {
} catch (e, s) {
_logger.severe("Could not dispose mobilefacenet", e, s);
}
OrtEnv.instance.release();
try {
await ClipImageEncoder.instance.release();
} catch (e, s) {
_logger.severe("Could not dispose clip image", e, s);
}
_isModelsInitialized = false;
});
}
@@ -681,9 +700,13 @@ class FaceMlService {
}
/// Analyzes the given image data by running the full pipeline for faces, using [_analyzeImageSync] in the isolate.
Future<MLResult?> _analyzeImageInSingleIsolate(EnteFile enteFile) async {
final String filePath =
await getImagePathForML(enteFile, typeOfData: FileDataForML.fileData);
Future<MLResult?> _analyzeImageInSingleIsolate(
FileMLInstruction instruction,
) async {
final String filePath = await getImagePathForML(
instruction.enteFile,
typeOfData: FileDataForML.fileData,
);
final Stopwatch stopwatch = Stopwatch()..start();
late MLResult result;
@@ -693,12 +716,15 @@ class FaceMlService {
(
FaceMlOperation.analyzeImage,
{
"enteFileID": enteFile.uploadedFileID ?? -1,
"enteFileID": instruction.enteFile.uploadedFileID ?? -1,
"filePath": filePath,
"runFaces": instruction.shouldRunFaces,
"runClip": instruction.shouldRunClip,
"faceDetectionAddress":
FaceDetectionService.instance.sessionAddress,
"faceEmbeddingAddress":
FaceEmbeddingService.instance.sessionAddress,
"clipImageAddress": ClipImageEncoder.instance.sessionAddress,
}
),
) as String?;
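
Note: only sendable values (ints, bools, strings) can cross the isolate boundary, which is why the ONNX sessions travel as raw native addresses (`sessionAddress`) plus two booleans rather than as objects; the worker re-wraps the address, as `ClipImageEncoder.predict` does in the second file. A sketch of both sides:

```dart
// Main isolate: flatten everything to sendable primitives.
final message = {
  "runClip": instruction.shouldRunClip,
  "clipImageAddress": ClipImageEncoder.instance.sessionAddress, // int
};

// Worker isolate: rebind the native session from the address.
final session = OrtSession.fromAddress(message["clipImageAddress"] as int);
```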
@@ -711,20 +737,21 @@ class FaceMlService {
result = MLResult.fromJsonString(resultJsonString);
} catch (e, s) {
_logger.severe(
"Could not analyze image with ID ${enteFile.uploadedFileID} \n",
"Could not analyze image with ID ${instruction.enteFile.uploadedFileID} \n",
e,
s,
);
debugPrint(
"This image with ID ${enteFile.uploadedFileID} has name ${enteFile.displayName}.",
"This image with ID ${instruction.enteFile.uploadedFileID} has name ${instruction.enteFile.displayName}.",
);
final resultBuilder = MLResult.fromEnteFileID(enteFile.uploadedFileID!)
..errorOccurred();
final resultBuilder =
MLResult.fromEnteFileID(instruction.enteFile.uploadedFileID!)
..errorOccurred();
return resultBuilder;
}
stopwatch.stop();
_logger.info(
"Finished Analyze image with uploadedFileID ${enteFile.uploadedFileID}, in "
"Finished Analyze image with uploadedFileID ${instruction.enteFile.uploadedFileID}, in "
"${stopwatch.elapsedMilliseconds} ms (including time waiting for inference engine availability)",
);
@@ -735,8 +762,11 @@ class FaceMlService {
try {
final int enteFileID = args["enteFileID"] as int;
final String imagePath = args["filePath"] as String;
final bool runFaces = args["runFaces"] as bool;
final bool runClip = args["runClip"] as bool;
final int faceDetectionAddress = args["faceDetectionAddress"] as int;
final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int;
final int clipImageAddress = args["clipImageAddress"] as int;
dev.log(
"Start analyzing image with uploadedFileID: $enteFileID inside the isolate",
@@ -754,18 +784,29 @@ class FaceMlService {
final result = MLResult.fromEnteFileID(enteFileID);
result.decodedImageSize = decodedImageSize;
final resultFaces = await FaceRecognitionService.runFacesPipeline(
enteFileID,
image,
imageByteData,
faceDetectionAddress,
faceEmbeddingAddress,
);
if (resultFaces.isEmpty) {
return result..noFaceDetected();
if (runFaces) {
final resultFaces = await FaceRecognitionService.runFacesPipeline(
enteFileID,
image,
imageByteData,
faceDetectionAddress,
faceEmbeddingAddress,
);
if (resultFaces.isEmpty) {
return result..noFaceDetected();
}
result.faces = resultFaces;
}
result.faces = resultFaces;
if (runClip) {
final clipResult = await SemanticSearchService.runClipImage(
enteFileID,
image,
imageByteData,
clipImageAddress,
);
result.clip = clipResult;
}
return result;
} catch (e, s) {
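
Note: one subtlety in the new gating: when `runFaces` is true and the pipeline finds no faces, the early `return result..noFaceDetected();` fires before the `runClip` block, so a face-less photo skips CLIP on that pass. If that is unintended, computing CLIP first would avoid it; a hypothetical reordering (not what this diff does):

```dart
// Hypothetical variant: run CLIP before the no-face early return so
// face-less photos still get an image embedding.
if (runClip) {
  final clipResult = await SemanticSearchService.runClipImage(
    enteFileID, image, imageByteData, clipImageAddress,
  );
  result.clip = clipResult;
}
if (runFaces) {
  final resultFaces = await FaceRecognitionService.runFacesPipeline(
    enteFileID, image, imageByteData,
    faceDetectionAddress, faceEmbeddingAddress,
  );
  if (resultFaces.isEmpty) return result..noFaceDetected();
  result.faces = resultFaces;
}
return result;
```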

View File

@@ -1,5 +1,5 @@
import "dart:io";
import "dart:typed_data";
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
@@ -25,17 +25,17 @@ class ClipImageEncoder extends MlModel {
static final instance = ClipImageEncoder._privateConstructor();
factory ClipImageEncoder() => instance;
static Future<List<double>> predict(Map args) async {
final imageData = await File(args["imagePath"]).readAsBytes();
final image = await decodeImageFromData(imageData);
final ByteData imgByteData = await getByteDataFromImage(image);
final inputList = await preprocessImageClip(image, imgByteData);
static Future<List<double>> predict(
Image image,
ByteData imageByteData,
int sessionAddress,
) async {
final inputList = await preprocessImageClip(image, imageByteData);
final inputOrt =
OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
final inputs = {'input': inputOrt};
final session = OrtSession.fromAddress(args["address"]);
final session = OrtSession.fromAddress(sessionAddress);
final runOptions = OrtRunOptions();
final outputs = session.run(runOptions, inputs);
final embedding = (outputs[0]?.value as List<List<double>>)[0];
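
Note: `predict` previously did its own file I/O and decoding; it now takes an already-decoded `Image` plus its `ByteData` and a session address, so the caller decodes once and feeds both the face pipeline and CLIP (the `dart:io` import and `Map args` plumbing disappear accordingly). A caller-side sketch, using the same helpers the old body called internally; `clipSessionAddress` stands in for whatever address the caller holds:

```dart
// One decode, two consumers (sketch; decodeImageFromData and
// getByteDataFromImage are the image_ml_util helpers the old body used).
final bytes = await File(imagePath).readAsBytes();
final image = await decodeImageFromData(bytes);
final byteData = await getByteDataFromImage(image);

final embedding =
    await ClipImageEncoder.predict(image, byteData, clipSessionAddress);
```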

View File

@@ -2,9 +2,11 @@ import "dart:async";
import "dart:collection";
import "dart:developer" as dev show log;
import "dart:math" show min;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;
import "package:computer/computer.dart";
import "package:flutter/services.dart";
import "package:flutter/services.dart" show PlatformException;
import "package:logging/logging.dart";
import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/configuration.dart";
@@ -17,14 +19,13 @@ import "package:photos/events/file_uploaded_event.dart";
import "package:photos/events/machine_learning_control_event.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
// import "package:photos/utils/thumbnail_util.dart";
@@ -44,12 +45,9 @@ class SemanticSearchService {
final _logger = Logger("SemanticSearchService");
final _queue = Queue<EnteFile>();
final _frameworkInitialization = Completer<bool>();
final _embeddingLoaderDebouncer =
Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
late Model _currentModel;
late MLFramework _mlFramework;
bool _hasInitialized = false;
bool _isComputingEmbeddings = false;
bool _isSyncing = false;
@@ -70,12 +68,6 @@ class SemanticSearchService {
return;
}
_hasInitialized = true;
final shouldDownloadOverMobileData =
Configuration.instance.shouldBackupOverMobileData();
_currentModel = await _getCurrentModel();
_mlFramework = _currentModel == Model.onnxClip
? ONNX(shouldDownloadOverMobileData)
: GGML(shouldDownloadOverMobileData);
await EmbeddingStore.instance.init();
await EmbeddingsDB.instance.init();
await _loadEmbeddings();
@@ -115,19 +107,12 @@ class SemanticSearchService {
});
}
Future<void> release() async {
if (_frameworkInitialization.isCompleted) {
await _mlFramework.release();
}
}
Future<void> sync() async {
if (_isSyncing) {
return;
}
_isSyncing = true;
final fetchCompleted =
await EmbeddingStore.instance.pullEmbeddings(_currentModel);
final fetchCompleted = await EmbeddingStore.instance.pullEmbeddings();
if (fetchCompleted) {
await _backFill();
}
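
Note: with the GGML/ONNX framework split gone (see the removed `_getCurrentModel` further down), `pullEmbeddings` and friends no longer take a `Model` argument, and `release()` loses its framework branch. The `_isSyncing` flag is a plain re-entrancy guard; its reset happens outside the visible hunk. The usual shape of that guard, as an assumed sketch:

```dart
// Re-entrancy guard: overlapping sync() calls are dropped and the flag is
// always cleared, even on failure. The reset in the real code sits past
// the hunk boundary, so the `finally` here is an assumption.
Future<void> sync() async {
  if (_isSyncing) return;
  _isSyncing = true;
  try {
    final fetched = await EmbeddingStore.instance.pullEmbeddings();
    if (fetched) await _backFill();
  } finally {
    _isSyncing = false;
  }
}
```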
@@ -183,14 +168,14 @@ class SemanticSearchService {
}
Future<void> clearIndexes() async {
await EmbeddingStore.instance.clearEmbeddings(_currentModel);
_logger.info("Indexes cleared for $_currentModel");
await EmbeddingStore.instance.clearEmbeddings();
_logger.info("Indexes cleared");
}
Future<void> _loadEmbeddings() async {
_logger.info("Pulling cached embeddings");
final startTime = DateTime.now();
_cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
_cachedEmbeddings = await EmbeddingsDB.instance.getAll();
final endTime = DateTime.now();
_logger.info(
"Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
@@ -219,10 +204,10 @@ class SemanticSearchService {
Future<List<int>> _getFileIDsToBeIndexed() async {
final uploadedFileIDs = await getIndexableFileIDs();
final embeddedFileIDs =
await EmbeddingsDB.instance.getFileIDs(_currentModel);
final embeddedFileIDs = await EmbeddingsDB.instance.getIndexedFileIds();
embeddedFileIDs.removeWhere((key, value) => value < clipMlVersion);
return uploadedFileIDs.difference(embeddedFileIDs).toList();
return uploadedFileIDs.difference(embeddedFileIDs.keys.toSet()).toList();
}
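
Note: `_getFileIDsToBeIndexed` now works off a fileID-to-version map: entries indexed below `clipMlVersion` are dropped first, so files embedded by an older model version are treated as unindexed and re-queued. A toy walkthrough of the selection:

```dart
// Toy data; assume clipMlVersion == 3.
final uploaded = {1, 2, 3, 4};
final indexed = {1: 3, 2: 2, 3: 3}; // fileID -> clip ml version
indexed.removeWhere((id, version) => version < 3); // drops file 2
final toIndex = uploaded.difference(indexed.keys.toSet()); // {2, 4}
```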
Future<void> clearQueue() async {
@@ -386,7 +371,6 @@ class SemanticSearchService {
final embedding = Embedding(
fileID: file.uploadedFileID!,
model: _currentModel,
embedding: result,
);
await EmbeddingStore.instance.storeEmbedding(
@@ -398,7 +382,7 @@ class SemanticSearchService {
"Could not get embedding for $file because FormatException occured, storing empty result locally",
e,
);
final embedding = Embedding.empty(file.uploadedFileID!, _currentModel);
final embedding = Embedding.empty(file.uploadedFileID!);
await EmbeddingsDB.instance.put(embedding);
} on PlatformException catch (e, s) {
_logger.severe(
@@ -406,13 +390,32 @@ class SemanticSearchService {
e,
s,
);
final embedding = Embedding.empty(file.uploadedFileID!, _currentModel);
final embedding = Embedding.empty(file.uploadedFileID!);
await EmbeddingsDB.instance.put(embedding);
} catch (e, s) {
_logger.severe(e, s);
}
}
static Future<void> storeClipImageResult(
ClipResult clipResult,
EnteFile entefile,
) async {
final embedding = Embedding(
fileID: clipResult.fileID,
embedding: clipResult.embedding,
);
await EmbeddingStore.instance.storeEmbedding(
entefile,
embedding,
);
}
static Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
final embedding = Embedding.empty(entefile.uploadedFileID!);
await EmbeddingsDB.instance.put(embedding);
}
Future<List<double>> _getTextEmbedding(String query) async {
_logger.info("Searching for " + query);
final cachedResult = _queryCache.get(query);
@@ -453,14 +456,6 @@ class SemanticSearchService {
return queryResults;
}
Future<Model> _getCurrentModel() async {
if (await isGrapheneOS()) {
return Model.ggmlClip;
} else {
return Model.onnxClip;
}
}
void _startIndexing() {
_logger.info("Start indexing");
if (!_mlController.isCompleted) {
@@ -474,6 +469,19 @@ class SemanticSearchService {
_mlController = Completer<void>();
}
}
static Future<ClipResult> runClipImage(
int enteFileID,
Image image,
ByteData imageByteData,
int clipImageAddress,
) async {
final embedding =
await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);
return clipResult;
}
}
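
Note: `runClipImage`, `storeClipImageResult` and `storeEmptyClipImageResult` are static by design: `runClipImage` executes inside the FaceML worker isolate (see the `runClip` block in the first file), where no `SemanticSearchService` instance state exists, so everything it needs arrives as arguments. The end-to-end flow in brief:

```dart
// Worker isolate: compute the embedding (no instance state available).
final ClipResult clip = await SemanticSearchService.runClipImage(
  enteFileID, image, imageByteData, clipImageAddress,
);

// Main isolate, after the result is sent back: persist it.
await SemanticSearchService.storeClipImageResult(clip, enteFile);
```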
List<QueryResult> computeBulkSimilarities(Map args) {