[mob] Use stream to optimize processing

This commit is contained in:
Neeraj Gupta
2024-07-24 13:45:18 +05:30
parent 6d7d7b88a6
commit e134d599b7
3 changed files with 47 additions and 29 deletions

View File

@@ -67,7 +67,6 @@ class FaceRecognitionService {
Future<void> sync() async {
await _syncPersonFeedback();
if (LocalSettings.instance.remoteFetchEnabled) {
await _syncFaceEmbeddings();
} else {
_logger.severe(
'Not fetching embeddings because user manually disabled it in debug options',
@@ -88,22 +87,24 @@ class FaceRecognitionService {
_isSyncing = false;
}
Future<List<FileMLInstruction>> _syncFaceEmbeddings() async {
Stream<List<FileMLInstruction>> syncEmbeddings({
int yieldSize = 10,
}) async* {
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
final Map<int, FileMLInstruction> pendingIndex = {};
final List<List<FileMLInstruction>> chunks =
filesToIndex.chunks(_embeddingFetchLimit); // Chunks of 200
int filesIndexedForFaces = 0;
int filesIndexedForClip = 0;
filesToIndex.chunks(_embeddingFetchLimit);
List<FileMLInstruction> batchToYield = [];
for (final chunk in chunks) {
final Set<int> ids = {};
final Map<int, FileMLInstruction> pendingIndex = {};
for (final instruction in chunk) {
ids.add(instruction.file.uploadedFileID!);
pendingIndex[instruction.file.uploadedFileID!] = instruction;
}
_logger.info('starting remote fetch for ${ids.length} files');
_logger.info("fetching embeddings for ${ids.length} files");
final res = await RemoteFileMLService.instance.getFileEmbeddings(ids);
_logger.info('fetched ${res.mlData.length} embeddings');
_logger.info("embeddingResponse ${res.debugLog()}");
final List<Face> faces = [];
final List<ClipEmbedding> clipEmbeddings = [];
for (RemoteFileML fileMl in res.mlData.values) {
@@ -112,7 +113,6 @@ class FaceRecognitionService {
//Note: Always do null check, empty value means no face was found.
if (facesFromRemoteEmbedding != null) {
faces.addAll(facesFromRemoteEmbedding);
filesIndexedForFaces++;
existingInstruction.shouldRunFaces = false;
}
if (fileMl.clipEmbedding != null &&
@@ -125,7 +125,6 @@ class FaceRecognitionService {
),
);
existingInstruction.shouldRunClip = false;
filesIndexedForClip++;
}
if (!existingInstruction.pendingML) {
pendingIndex.remove(fileMl.fileID);
@@ -134,23 +133,30 @@ class FaceRecognitionService {
pendingIndex[fileMl.fileID] = existingInstruction;
}
}
for (final fileID in pendingIndex.keys) {
final instruction = pendingIndex[fileID]!;
if (instruction.pendingML) {
batchToYield.add(instruction);
if (batchToYield.length == yieldSize) {
_logger.info("queueing indexing for $yieldSize");
yield batchToYield;
batchToYield = [];
}
}
}
if (res.noEmbeddingFileIDs.isNotEmpty) {
_logger.info(
'No embeddings found for ${res.noEmbeddingFileIDs.length} files',
);
for (final fileID in res.noEmbeddingFileIDs) {
faces.add(Face.empty(fileID, error: false));
}
}
await FaceMLDataDB.instance.bulkInsertFaces(faces);
await EmbeddingsDB.instance.putMany(clipEmbeddings);
_logger.info(
'Embedding store files for face $filesIndexedForFaces, and clip $filesIndexedForClip',
);
}
final List<FileMLInstruction> instructions = pendingIndex.values.toList();
return instructions;
// Yield any remaining instructions
if (batchToYield.isNotEmpty) {
_logger.info("queueing indexing for ${batchToYield.length}");
yield batchToYield;
}
}
// Returns a list of faces from the given remote fileML. null if the version is less than the current version

View File

@@ -16,4 +16,17 @@ class FilesMLDataResponse {
required this.fetchErrorFileIDs,
required this.pendingIndexFileIDs,
});
/// Returns a compact one-line summary of this response for logging.
///
/// Always reports the number of fetched embeddings; the remaining counts
/// are appended only when their backing lists are non-empty, keeping the
/// common all-good log line short.
String debugLog() {
  // NOTE(review): `noEmbeddingFileIDs` is labelled "smallEmbeddings" in the
  // output — presumably server-side terminology; confirm before renaming.
  final summary = StringBuffer('MLRemote(mlData: ${mlData.length}');
  if (noEmbeddingFileIDs.isNotEmpty) {
    summary.write(', smallEmbeddings: ${noEmbeddingFileIDs.length}');
  }
  if (fetchErrorFileIDs.isNotEmpty) {
    summary.write(', errorForFileIDs: ${fetchErrorFileIDs.length}');
  }
  if (pendingIndexFileIDs.isNotEmpty) {
    summary.write(', pendingIndexFileIDs: ${pendingIndexFileIDs.length}');
  }
  summary.write(')');
  return summary.toString();
}
}

View File

@@ -14,7 +14,6 @@ import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/events/machine_learning_control_event.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/list.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/box.dart";
import "package:photos/face/model/detection.dart" as face_detection;
@@ -61,7 +60,9 @@ class MLService {
// Singleton pattern
MLService._privateConstructor();
static final instance = MLService._privateConstructor();
factory MLService() => instance;
final _initModelLock = Lock();
@@ -175,26 +176,25 @@ class MLService {
try {
_isIndexingOrClusteringRunning = true;
_logger.info('starting image indexing');
final filesToIndex = await getFilesForMlIndexing();
final List<List<FileMLInstruction>> chunks =
filesToIndex.chunks(_fileDownloadLimit);
final Stream<List<FileMLInstruction>> instructionStream =
FaceRecognitionService.instance
.syncEmbeddings(yieldSize: _fileDownloadLimit);
int fileAnalyzedCount = 0;
final Stopwatch stopwatch = Stopwatch()..start();
outerLoop:
for (final chunk in chunks) {
await for (final chunk in instructionStream) {
if (!await canUseHighBandwidth()) {
_logger.info(
'stopping indexing because user is not connected to wifi',
);
break outerLoop;
break;
}
final futures = <Future<bool>>[];
for (final instruction in chunk) {
if (_shouldPauseIndexingAndClustering) {
_logger.info("indexAllImages() was paused, stopping");
break outerLoop;
break;
}
await _ensureReadyForInference();
futures.add(processImage(instruction));
@@ -206,7 +206,6 @@ class MLService {
);
fileAnalyzedCount += sumFutures;
}
_logger.info(
"`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image)",
);