Ml debug (#2847)
Various minor fixes for ML: - Only download model on Wifi and notify user - Fix ML queue not stopping on interaction - Partial fix for ML stopping randomly - Fix ML queue not stopping on disabling ML - Update ML debug options
This commit is contained in:
@@ -660,6 +660,14 @@ class MLDataDB {
|
||||
return maps.first['count'] as int;
|
||||
}
|
||||
|
||||
Future<int> getErroredFaceCount() async {
|
||||
final db = await instance.asyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore < 0',
|
||||
);
|
||||
return maps.first['count'] as int;
|
||||
}
|
||||
|
||||
Future<int> getClusteredOrFacelessFileCount() async {
|
||||
final db = await instance.asyncDB;
|
||||
final List<Map<String, dynamic>> clustered = await db.getAll(
|
||||
@@ -942,7 +950,7 @@ class MLDataDB {
|
||||
}
|
||||
|
||||
/// WARNING: This will delete ALL data in the tables! Only use this for debug/testing purposes!
|
||||
Future<void> dropFeedbackTables() async {
|
||||
Future<void> dropFacesFeedbackTables() async {
|
||||
try {
|
||||
final db = await instance.asyncDB;
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import "dart:io" show Platform;
|
||||
|
||||
const faceMlVersion = 1;
|
||||
const clipMlVersion = 1;
|
||||
const clusterMlVersion = 1;
|
||||
const minimumClusterSize = 2;
|
||||
|
||||
const embeddingFetchLimit = 200;
|
||||
final fileDownloadMlLimit = Platform.isIOS ? 5 : 10;
|
||||
|
||||
@@ -5,17 +5,8 @@ import "dart:ui" show Image;
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/ml/clip_db.dart";
|
||||
import "package:photos/db/ml/db.dart";
|
||||
import "package:photos/events/diff_sync_complete_event.dart";
|
||||
import "package:photos/events/people_changed_event.dart";
|
||||
import "package:photos/extensions/list.dart";
|
||||
import "package:photos/models/ml/clip.dart";
|
||||
import "package:photos/models/ml/face/face.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:photos/service_locator.dart";
|
||||
import "package:photos/services/filedata/filedata_service.dart";
|
||||
import "package:photos/services/filedata/model/file_data.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart";
|
||||
@@ -23,7 +14,6 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d
|
||||
import "package:photos/services/machine_learning/ml_exceptions.dart";
|
||||
import "package:photos/services/machine_learning/ml_result.dart";
|
||||
import "package:photos/utils/image_ml_util.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
class FaceRecognitionService {
|
||||
final _logger = Logger("FaceRecognitionService");
|
||||
@@ -40,8 +30,6 @@ class FaceRecognitionService {
|
||||
bool _shouldSyncPeople = false;
|
||||
bool _isSyncing = false;
|
||||
|
||||
static const _embeddingFetchLimit = 200;
|
||||
|
||||
Future<void> init() async {
|
||||
if (_isInitialized) {
|
||||
return;
|
||||
@@ -65,12 +53,6 @@ class FaceRecognitionService {
|
||||
|
||||
Future<void> sync() async {
|
||||
await _syncPersonFeedback();
|
||||
if (localSettings.remoteFetchEnabled) {
|
||||
} else {
|
||||
_logger.severe(
|
||||
'Not fetching embeddings because user manually disabled it in debug options',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _syncPersonFeedback() async {
|
||||
@@ -86,97 +68,6 @@ class FaceRecognitionService {
|
||||
_isSyncing = false;
|
||||
}
|
||||
|
||||
Stream<List<FileMLInstruction>> syncEmbeddings({
|
||||
int yieldSize = 10,
|
||||
}) async* {
|
||||
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
|
||||
final List<List<FileMLInstruction>> chunks =
|
||||
filesToIndex.chunks(_embeddingFetchLimit);
|
||||
List<FileMLInstruction> batchToYield = [];
|
||||
|
||||
for (final chunk in chunks) {
|
||||
final Set<int> ids = {};
|
||||
final Map<int, FileMLInstruction> pendingIndex = {};
|
||||
for (final instruction in chunk) {
|
||||
ids.add(instruction.file.uploadedFileID!);
|
||||
pendingIndex[instruction.file.uploadedFileID!] = instruction;
|
||||
}
|
||||
_logger.info("fetching embeddings for ${ids.length} files");
|
||||
final res = await FileDataService.instance.getFilesData(ids);
|
||||
_logger.info("embeddingResponse ${res.debugLog()}");
|
||||
final List<Face> faces = [];
|
||||
final List<ClipEmbedding> clipEmbeddings = [];
|
||||
for (FileDataEntity fileMl in res.data.values) {
|
||||
final existingInstruction = pendingIndex[fileMl.fileID]!;
|
||||
final facesFromRemoteEmbedding = _getFacesFromRemoteEmbedding(fileMl);
|
||||
//Note: Always do null check, empty value means no face was found.
|
||||
if (facesFromRemoteEmbedding != null) {
|
||||
faces.addAll(facesFromRemoteEmbedding);
|
||||
existingInstruction.shouldRunFaces = false;
|
||||
}
|
||||
if (fileMl.clipEmbedding != null &&
|
||||
fileMl.clipEmbedding!.version >= clipMlVersion) {
|
||||
clipEmbeddings.add(
|
||||
ClipEmbedding(
|
||||
fileID: fileMl.fileID,
|
||||
embedding: fileMl.clipEmbedding!.embedding,
|
||||
version: fileMl.clipEmbedding!.version,
|
||||
),
|
||||
);
|
||||
existingInstruction.shouldRunClip = false;
|
||||
}
|
||||
if (!existingInstruction.pendingML) {
|
||||
pendingIndex.remove(fileMl.fileID);
|
||||
} else {
|
||||
existingInstruction.existingRemoteFileML = fileMl;
|
||||
pendingIndex[fileMl.fileID] = existingInstruction;
|
||||
}
|
||||
}
|
||||
for (final fileID in pendingIndex.keys) {
|
||||
final instruction = pendingIndex[fileID]!;
|
||||
if (instruction.pendingML) {
|
||||
batchToYield.add(instruction);
|
||||
if (batchToYield.length == yieldSize) {
|
||||
_logger.info("queueing indexing for $yieldSize");
|
||||
yield batchToYield;
|
||||
batchToYield = [];
|
||||
}
|
||||
}
|
||||
}
|
||||
await MLDataDB.instance.bulkInsertFaces(faces);
|
||||
await MLDataDB.instance.putMany(clipEmbeddings);
|
||||
}
|
||||
// Yield any remaining instructions
|
||||
if (batchToYield.isNotEmpty) {
|
||||
_logger.info("queueing indexing for ${batchToYield.length}");
|
||||
yield batchToYield;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a list of faces from the given remote fileML. null if the version is less than the current version
|
||||
// or if the remote faceEmbedding is null.
|
||||
List<Face>? _getFacesFromRemoteEmbedding(FileDataEntity fileMl) {
|
||||
final RemoteFaceEmbedding? remoteFaceEmbedding = fileMl.faceEmbedding;
|
||||
if (shouldDiscardRemoteEmbedding(fileMl)) {
|
||||
return null;
|
||||
}
|
||||
final List<Face> faces = [];
|
||||
if (remoteFaceEmbedding!.faces.isEmpty) {
|
||||
faces.add(
|
||||
Face.empty(fileMl.fileID),
|
||||
);
|
||||
} else {
|
||||
for (final f in remoteFaceEmbedding.faces) {
|
||||
f.fileInfo = FileInfo(
|
||||
imageHeight: remoteFaceEmbedding.height,
|
||||
imageWidth: remoteFaceEmbedding.width,
|
||||
);
|
||||
faces.add(f);
|
||||
}
|
||||
}
|
||||
return faces;
|
||||
}
|
||||
|
||||
static Future<List<FaceResult>> runFacesPipeline(
|
||||
int enteFileID,
|
||||
Image image,
|
||||
|
||||
@@ -158,8 +158,8 @@ class MLComputer {
|
||||
}
|
||||
|
||||
Future<List<double>> runClipText(String query) async {
|
||||
await _ensureLoadedClipTextModel();
|
||||
try {
|
||||
await _ensureLoadedClipTextModel();
|
||||
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
|
||||
final textEmbedding = await _runInIsolate(
|
||||
(
|
||||
@@ -195,9 +195,10 @@ class MLComputer {
|
||||
|
||||
// Load ClipText model
|
||||
final String modelName = ClipTextEncoder.instance.modelName;
|
||||
final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath;
|
||||
final String modelPath =
|
||||
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
final String? modelPath = await ClipTextEncoder.instance.downloadModelSafe();
|
||||
if (modelPath == null) {
|
||||
throw Exception("Could not download clip text model, no wifi");
|
||||
}
|
||||
final address = await _runInIsolate(
|
||||
(
|
||||
MLComputerOperation.loadModel,
|
||||
|
||||
@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:photos/services/machine_learning/onnx_env.dart";
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/utils/network_util.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
abstract class MlModel {
|
||||
@@ -32,6 +33,7 @@ abstract class MlModel {
|
||||
bool _isNativePluginInitialized = false;
|
||||
int _nativePluginSessionIndex = -1;
|
||||
|
||||
/// WARNING: If [downloadModel] was not first called, this method will download the model first using high bandwidth.
|
||||
Future<(String, String)> getModelNameAndPath() async {
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
final path =
|
||||
@@ -40,12 +42,29 @@ abstract class MlModel {
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> downloadModel([bool forceRefresh = false]) async {
|
||||
Future<String?> downloadModelSafe() async {
|
||||
if (await RemoteAssetsService.instance.hasAsset(modelRemotePath)) {
|
||||
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
} else {
|
||||
if (await canUseHighBandwidth()) {
|
||||
return await downloadModel();
|
||||
} else {
|
||||
logger.warning(
|
||||
'Cannot return model path as it is not available locally and high bandwidth is not available.',
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<String> downloadModel([bool forceRefresh = false]) async {
|
||||
return _downloadModelLock.synchronized(() async {
|
||||
if (forceRefresh) {
|
||||
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
|
||||
final file = await RemoteAssetsService.instance
|
||||
.getAssetIfUpdated(modelRemotePath);
|
||||
return file!.path;
|
||||
} else {
|
||||
await RemoteAssetsService.instance.getAsset(modelRemotePath);
|
||||
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -55,9 +55,9 @@ class MLService {
|
||||
bool _showClusteringIsHappening = false;
|
||||
bool _mlControllerStatus = false;
|
||||
bool _isIndexingOrClusteringRunning = false;
|
||||
bool _isRunningML = false;
|
||||
bool _shouldPauseIndexingAndClustering = false;
|
||||
|
||||
static final int _fileDownloadLimit = Platform.isIOS ? 5 : 10;
|
||||
static const _kForceClusteringFaceCount = 8000;
|
||||
|
||||
/// Only call this function once at app startup, after that you can directly call [runAllML]
|
||||
@@ -118,6 +118,7 @@ class MLService {
|
||||
_mlControllerStatus = true;
|
||||
}
|
||||
if (_cannotRunMLFunction() && !force) return;
|
||||
_isRunningML = true;
|
||||
|
||||
await sync();
|
||||
|
||||
@@ -134,6 +135,16 @@ class MLService {
|
||||
} catch (e, s) {
|
||||
_logger.severe("runAllML failed", e, s);
|
||||
rethrow;
|
||||
} finally {
|
||||
_isRunningML = false;
|
||||
}
|
||||
}
|
||||
|
||||
void triggerML() {
|
||||
if (_mlControllerStatus &&
|
||||
!_isIndexingOrClusteringRunning &&
|
||||
!_isRunningML) {
|
||||
unawaited(runAllML());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,24 +170,26 @@ class MLService {
|
||||
_isIndexingOrClusteringRunning = true;
|
||||
_logger.info('starting image indexing');
|
||||
final Stream<List<FileMLInstruction>> instructionStream =
|
||||
FaceRecognitionService.instance
|
||||
.syncEmbeddings(yieldSize: _fileDownloadLimit);
|
||||
fetchEmbeddingsAndInstructions(fileDownloadMlLimit);
|
||||
|
||||
int fileAnalyzedCount = 0;
|
||||
final Stopwatch stopwatch = Stopwatch()..start();
|
||||
|
||||
stream:
|
||||
await for (final chunk in instructionStream) {
|
||||
if (!await canUseHighBandwidth()) {
|
||||
_logger.info(
|
||||
'stopping indexing because user is not connected to wifi',
|
||||
);
|
||||
break;
|
||||
break stream;
|
||||
} else {
|
||||
await _ensureDownloadedModels();
|
||||
}
|
||||
final futures = <Future<bool>>[];
|
||||
for (final instruction in chunk) {
|
||||
if (_shouldPauseIndexingAndClustering) {
|
||||
_logger.info("indexAllImages() was paused, stopping");
|
||||
break;
|
||||
break stream;
|
||||
}
|
||||
await _ensureLoadedModels(instruction);
|
||||
futures.add(processImage(instruction));
|
||||
@@ -491,6 +504,13 @@ class MLService {
|
||||
}
|
||||
}
|
||||
|
||||
void triggerModelsDownload() {
|
||||
if (!areModelsDownloaded && !_downloadModelLock.locked) {
|
||||
_logger.info("Models not downloaded, starting download");
|
||||
unawaited(_ensureDownloadedModels());
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _ensureDownloadedModels([bool forceRefresh = false]) async {
|
||||
if (_downloadModelLock.locked) {
|
||||
_logger.finest("Download models already in progress");
|
||||
@@ -500,6 +520,13 @@ class MLService {
|
||||
_logger.finest("Models already downloaded");
|
||||
return;
|
||||
}
|
||||
final goodInternet = await canUseHighBandwidth();
|
||||
if (!goodInternet) {
|
||||
_logger.info(
|
||||
"Cannot download models because user is not connected to wifi",
|
||||
);
|
||||
return;
|
||||
}
|
||||
_logger.info('Downloading models');
|
||||
await Future.wait([
|
||||
FaceDetectionService.instance.downloadModel(forceRefresh),
|
||||
@@ -534,7 +561,7 @@ class MLService {
|
||||
}
|
||||
|
||||
bool _cannotRunMLFunction({String function = ""}) {
|
||||
if (kDebugMode && Platform.isIOS) {
|
||||
if (kDebugMode && Platform.isIOS && !_isIndexingOrClusteringRunning) {
|
||||
return false;
|
||||
}
|
||||
if (_isIndexingOrClusteringRunning) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import "package:photos/db/ml/db.dart";
|
||||
import "package:photos/events/people_changed_event.dart";
|
||||
import "package:photos/models/ml/face/person.dart";
|
||||
import "package:photos/service_locator.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||
import 'package:photos/services/machine_learning/ml_service.dart';
|
||||
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
||||
@@ -17,10 +17,11 @@ import 'package:photos/ui/components/expandable_menu_item_widget.dart';
|
||||
import 'package:photos/ui/components/menu_item_widget/menu_item_widget.dart';
|
||||
import 'package:photos/ui/settings/common_settings.dart';
|
||||
import "package:photos/utils/dialog_util.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
import 'package:photos/utils/toast_util.dart';
|
||||
|
||||
class MLDebugSectionWidget extends StatefulWidget {
|
||||
const MLDebugSectionWidget({Key? key}) : super(key: key);
|
||||
const MLDebugSectionWidget({super.key});
|
||||
|
||||
@override
|
||||
State<MLDebugSectionWidget> createState() => _MLDebugSectionWidgetState();
|
||||
@@ -54,18 +55,19 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
}
|
||||
|
||||
Widget _getSectionOptions(BuildContext context) {
|
||||
final Logger _logger = Logger("MLDebugSectionWidget");
|
||||
final Logger logger = Logger("MLDebugSectionWidget");
|
||||
return Column(
|
||||
children: [
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: FutureBuilder<int>(
|
||||
future: MLDataDB.instance.getFaceIndexedFileCount(),
|
||||
captionedTextWidget: FutureBuilder<IndexStatus>(
|
||||
future: getIndexStatus(),
|
||||
builder: (context, snapshot) {
|
||||
if (snapshot.hasData) {
|
||||
final IndexStatus status = snapshot.data!;
|
||||
return CaptionedTextWidget(
|
||||
title: localSettings.isMLIndexingEnabled
|
||||
? "Disable faces (${snapshot.data!} files done)"
|
||||
: "Enable faces (${snapshot.data!} files done)",
|
||||
? "Disable ML (${status.indexedItems} files indexed)"
|
||||
: "Enable ML (${status.indexedItems} files indexed)",
|
||||
);
|
||||
}
|
||||
return const SizedBox.shrink();
|
||||
@@ -77,14 +79,18 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
onTap: () async {
|
||||
try {
|
||||
final isEnabled = await localSettings.toggleMLIndexing();
|
||||
if (!isEnabled) {
|
||||
if (isEnabled) {
|
||||
await MLService.instance.init();
|
||||
await SemanticSearchService.instance.init();
|
||||
unawaited(MLService.instance.runAllML(force: true));
|
||||
} else {
|
||||
MLService.instance.pauseIndexingAndClustering();
|
||||
}
|
||||
if (mounted) {
|
||||
setState(() {});
|
||||
}
|
||||
} catch (e, s) {
|
||||
_logger.warning('indexing failed ', e, s);
|
||||
logger.warning('indexing failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -93,8 +99,8 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: localSettings.remoteFetchEnabled
|
||||
? "Remote fetch enabled"
|
||||
: "Remote fetch disabled",
|
||||
? "Disable remote fetch"
|
||||
: "Enable remote fetch",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -106,7 +112,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
setState(() {});
|
||||
}
|
||||
} catch (e, s) {
|
||||
_logger.warning('Remote fetch toggle failed ', e, s);
|
||||
logger.warning('Remote fetch toggle failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -115,8 +121,8 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: MLService.instance.debugIndexingDisabled
|
||||
? "Debug enable indexing again"
|
||||
: "Debug disable indexing",
|
||||
? "Enable auto indexing (debug)"
|
||||
: "Disable auto indexing (debug)",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -127,12 +133,14 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
!MLService.instance.debugIndexingDisabled;
|
||||
if (MLService.instance.debugIndexingDisabled) {
|
||||
MLService.instance.pauseIndexingAndClustering();
|
||||
} else {
|
||||
unawaited(MLService.instance.runAllML());
|
||||
}
|
||||
if (mounted) {
|
||||
setState(() {});
|
||||
}
|
||||
} catch (e, s) {
|
||||
_logger.warning('debugIndexingDisabled toggle failed ', e, s);
|
||||
logger.warning('debugIndexingDisabled toggle failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -140,7 +148,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Run sync, indexing, clustering",
|
||||
title: "Trigger run ML",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -150,7 +158,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
MLService.instance.debugIndexingDisabled = false;
|
||||
unawaited(MLService.instance.runAllML());
|
||||
} catch (e, s) {
|
||||
_logger.warning('indexAndClusterAll failed ', e, s);
|
||||
logger.warning('indexAndClusterAll failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -158,7 +166,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Run indexing",
|
||||
title: "Trigger run indexing",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -168,7 +176,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
MLService.instance.debugIndexingDisabled = false;
|
||||
unawaited(MLService.instance.indexAllImages());
|
||||
} catch (e, s) {
|
||||
_logger.warning('indexing failed ', e, s);
|
||||
logger.warning('indexing failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -181,7 +189,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
if (snapshot.hasData) {
|
||||
return CaptionedTextWidget(
|
||||
title:
|
||||
"Run clustering (${(100 * snapshot.data!).toStringAsFixed(0)}% done)",
|
||||
"Trigger clustering (${(100 * snapshot.data!).toStringAsFixed(0)}% done)",
|
||||
);
|
||||
}
|
||||
return const SizedBox.shrink();
|
||||
@@ -194,41 +202,41 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
try {
|
||||
await PersonService.instance.fetchRemoteClusterFeedback();
|
||||
MLService.instance.debugIndexingDisabled = false;
|
||||
await MLService.instance.clusterAllImages(clusterInBuckets: true);
|
||||
await MLService.instance.clusterAllImages();
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('clustering failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
),
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Check for mixed clusters",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
trailingIconIsMuted: true,
|
||||
onTap: () async {
|
||||
try {
|
||||
final susClusters =
|
||||
await ClusterFeedbackService.instance.checkForMixedClusters();
|
||||
for (final clusterinfo in susClusters) {
|
||||
Future.delayed(const Duration(seconds: 4), () {
|
||||
showToast(
|
||||
context,
|
||||
'Cluster with ${clusterinfo.$2} photos is sus',
|
||||
);
|
||||
});
|
||||
}
|
||||
} catch (e, s) {
|
||||
_logger.warning('Checking for mixed clusters failed', e, s);
|
||||
logger.warning('clustering failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
),
|
||||
// sectionOptionSpacing,
|
||||
// MenuItemWidget(
|
||||
// captionedTextWidget: const CaptionedTextWidget(
|
||||
// title: "Check for mixed clusters",
|
||||
// ),
|
||||
// pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
// trailingIcon: Icons.chevron_right_outlined,
|
||||
// trailingIconIsMuted: true,
|
||||
// onTap: () async {
|
||||
// try {
|
||||
// final susClusters =
|
||||
// await ClusterFeedbackService.instance.checkForMixedClusters();
|
||||
// for (final clusterinfo in susClusters) {
|
||||
// Future.delayed(const Duration(seconds: 4), () {
|
||||
// showToast(
|
||||
// context,
|
||||
// 'Cluster with ${clusterinfo.$2} photos is sus',
|
||||
// );
|
||||
// });
|
||||
// }
|
||||
// } catch (e, s) {
|
||||
// logger.warning('Checking for mixed clusters failed', e, s);
|
||||
// await showGenericErrorDialog(context: context, error: e);
|
||||
// }
|
||||
// },
|
||||
// ),
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
@@ -239,16 +247,28 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
trailingIconIsMuted: true,
|
||||
onTap: () async {
|
||||
try {
|
||||
await PersonService.instance.reconcileClusters();
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
await FaceRecognitionService.instance.sync();
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('sync person mappings failed ', e, s);
|
||||
logger.warning('sync person mappings failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
),
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Show empty indexes",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
trailingIconIsMuted: true,
|
||||
onTap: () async {
|
||||
final emptyFaces = await MLDataDB.instance.getErroredFaceCount();
|
||||
showShortToast(context, '$emptyFaces empty faces');
|
||||
},
|
||||
),
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Reset faces feedback",
|
||||
@@ -262,15 +282,15 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
context,
|
||||
title: "Are you sure?",
|
||||
body:
|
||||
"This will drop all people and their related feedback. It will keep clustering labels and embeddings untouched.",
|
||||
"This will drop all people and their related feedback stored locally. It will keep clustering labels and embeddings untouched, as well as persons stored on remote.",
|
||||
firstButtonLabel: "Yes, confirm",
|
||||
firstButtonOnTap: () async {
|
||||
try {
|
||||
await MLDataDB.instance.dropFeedbackTables();
|
||||
await MLDataDB.instance.dropFacesFeedbackTables();
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('reset feedback failed ', e, s);
|
||||
logger.warning('reset feedback failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -290,7 +310,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
context,
|
||||
title: "Are you sure?",
|
||||
body:
|
||||
"This will delete all people, their related feedback and clustering labels. It will keep embeddings untouched.",
|
||||
"This will delete all people (also from remote), their related feedback and clustering labels. It will keep embeddings untouched.",
|
||||
firstButtonLabel: "Yes, confirm",
|
||||
firstButtonOnTap: () async {
|
||||
try {
|
||||
@@ -303,7 +323,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('peopleToPersonMapping remove failed ', e, s);
|
||||
logger.warning('peopleToPersonMapping remove failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
@@ -313,7 +333,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Reset faces everything (embeddings)",
|
||||
title: "Reset all local faces",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -323,7 +343,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
context,
|
||||
title: "Are you sure?",
|
||||
body:
|
||||
"You will need to again re-index all the faces. You can drop feedback if you want to label again",
|
||||
"This will drop all local faces data. You will need to again re-index faces.",
|
||||
firstButtonLabel: "Yes, confirm",
|
||||
firstButtonOnTap: () async {
|
||||
try {
|
||||
@@ -332,16 +352,17 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('drop feedback failed ', e, s);
|
||||
logger.warning('drop feedback failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
),
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: const CaptionedTextWidget(
|
||||
title: "Reset clip embeddings",
|
||||
title: "Reset all local clip",
|
||||
),
|
||||
pressedColor: getEnteColorScheme(context).fillFaint,
|
||||
trailingIcon: Icons.chevron_right_outlined,
|
||||
@@ -358,7 +379,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
|
||||
await SemanticSearchService.instance.clearIndexes();
|
||||
showShortToast(context, "Done");
|
||||
} catch (e, s) {
|
||||
_logger.warning('drop clip embeddings failed ', e, s);
|
||||
logger.warning('drop clip embeddings failed ', e, s);
|
||||
await showGenericErrorDialog(context: context, error: e);
|
||||
}
|
||||
},
|
||||
|
||||
@@ -27,6 +27,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
|
||||
import "package:photos/ui/components/toggle_switch_widget.dart";
|
||||
import "package:photos/ui/settings/ml/enable_ml_consent.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
import "package:photos/utils/network_util.dart";
|
||||
import "package:photos/utils/wakelock_util.dart";
|
||||
|
||||
class MachineLearningSettingsPage extends StatefulWidget {
|
||||
@@ -203,6 +204,7 @@ class _MachineLearningSettingsPageState
|
||||
await SemanticSearchService.instance.init();
|
||||
unawaited(MLService.instance.runAllML(force: true));
|
||||
} else {
|
||||
MLService.instance.pauseIndexingAndClustering();
|
||||
await UserRemoteFlagService.instance
|
||||
.setBoolValue(UserRemoteFlagService.mlEnabled, false);
|
||||
}
|
||||
@@ -257,6 +259,8 @@ class ModelLoadingState extends StatefulWidget {
|
||||
class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
StreamSubscription<(String, int, int)>? _progressStream;
|
||||
final Map<String, (int, int)> _progressMap = {};
|
||||
Timer? _timer;
|
||||
|
||||
@override
|
||||
void initState() {
|
||||
_progressStream =
|
||||
@@ -277,6 +281,9 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
setState(() {});
|
||||
}
|
||||
});
|
||||
_timer = Timer.periodic(const Duration(seconds: 10), (timer) {
|
||||
setState(() {});
|
||||
});
|
||||
super.initState();
|
||||
}
|
||||
|
||||
@@ -284,6 +291,7 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
void dispose() {
|
||||
super.dispose();
|
||||
_progressStream?.cancel();
|
||||
_timer?.cancel();
|
||||
}
|
||||
|
||||
@override
|
||||
@@ -292,8 +300,25 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
children: [
|
||||
MenuSectionTitle(title: S.of(context).status),
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: _getTitle(context),
|
||||
captionedTextWidget: FutureBuilder(
|
||||
future: canUseHighBandwidth(),
|
||||
builder: (context, snapshot) {
|
||||
if (snapshot.hasData) {
|
||||
if (snapshot.data!) {
|
||||
MLService.instance.triggerModelsDownload();
|
||||
return CaptionedTextWidget(
|
||||
title: S.of(context).loadingModel,
|
||||
key: const ValueKey("loading_model"),
|
||||
);
|
||||
} else {
|
||||
return CaptionedTextWidget(
|
||||
title: S.of(context).waitingForWifi,
|
||||
key: const ValueKey("waiting_for_wifi"),
|
||||
);
|
||||
}
|
||||
}
|
||||
return const CaptionedTextWidget(title: "");
|
||||
},
|
||||
),
|
||||
trailingWidget: EnteLoadingWidget(
|
||||
size: 12,
|
||||
@@ -322,15 +347,6 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
String _getTitle(BuildContext context) {
|
||||
// TODO: uncomment code below to actually check for high bandwidth
|
||||
// final usableConnection = await canUseHighBandwidth();
|
||||
// if (!usableConnection) {
|
||||
// return S.of(context).waitingForWifi;
|
||||
// }
|
||||
return S.of(context).loadingModel;
|
||||
}
|
||||
}
|
||||
|
||||
class MLStatusWidget extends StatefulWidget {
|
||||
@@ -348,6 +364,7 @@ class MLStatusWidgetState extends State<MLStatusWidget> {
|
||||
void initState() {
|
||||
super.initState();
|
||||
_timer = Timer.periodic(const Duration(seconds: 10), (timer) {
|
||||
MLService.instance.triggerML();
|
||||
setState(() {});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -144,9 +144,7 @@ class SettingsPage extends StatelessWidget {
|
||||
|
||||
if (hasLoggedIn && flagService.internalUser) {
|
||||
contents.addAll([sectionSpacing, const DebugSectionWidget()]);
|
||||
if (flagService.isBetaUser) {
|
||||
contents.addAll([sectionSpacing, const MLDebugSectionWidget()]);
|
||||
}
|
||||
contents.addAll([sectionSpacing, const MLDebugSectionWidget()]);
|
||||
}
|
||||
contents.add(const AppVersionWidget());
|
||||
contents.add(const DeveloperSettingsWidget());
|
||||
|
||||
@@ -7,11 +7,16 @@ import "package:photos/core/configuration.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
import "package:photos/db/ml/clip_db.dart";
|
||||
import "package:photos/db/ml/db.dart";
|
||||
import "package:photos/extensions/list.dart";
|
||||
import "package:photos/models/file/extensions/file_props.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/file/file_type.dart";
|
||||
import "package:photos/models/ml/clip.dart";
|
||||
import "package:photos/models/ml/face/dimension.dart";
|
||||
import "package:photos/models/ml/face/face.dart";
|
||||
import "package:photos/models/ml/ml_versions.dart";
|
||||
import "package:photos/service_locator.dart";
|
||||
import "package:photos/services/filedata/filedata_service.dart";
|
||||
import "package:photos/services/filedata/model/file_data.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
|
||||
import "package:photos/services/machine_learning/ml_exceptions.dart";
|
||||
@@ -68,6 +73,7 @@ Future<IndexStatus> getIndexStatus() async {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a list of file instructions for files that should be indexed for ML
|
||||
Future<List<FileMLInstruction>> getFilesForMlIndexing() async {
|
||||
_logger.info('getFilesForMlIndexing called');
|
||||
final time = DateTime.now();
|
||||
@@ -146,7 +152,105 @@ Future<List<FileMLInstruction>> getFilesForMlIndexing() async {
|
||||
return sortedBylocalID;
|
||||
}
|
||||
|
||||
bool shouldDiscardRemoteEmbedding(FileDataEntity fileML) {
|
||||
Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
|
||||
int yieldSize,) async* {
|
||||
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
|
||||
final List<List<FileMLInstruction>> chunks =
|
||||
filesToIndex.chunks(embeddingFetchLimit);
|
||||
List<FileMLInstruction> batchToYield = [];
|
||||
|
||||
for (final chunk in chunks) {
|
||||
if (!localSettings.remoteFetchEnabled) {
|
||||
_logger.warning("remoteFetchEnabled is false, skiping embedding fetch");
|
||||
final batches = chunk.chunks(yieldSize);
|
||||
for (final batch in batches) {
|
||||
yield batch;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
final Set<int> ids = {};
|
||||
final Map<int, FileMLInstruction> pendingIndex = {};
|
||||
for (final instruction in chunk) {
|
||||
ids.add(instruction.file.uploadedFileID!);
|
||||
pendingIndex[instruction.file.uploadedFileID!] = instruction;
|
||||
}
|
||||
_logger.info("fetching embeddings for ${ids.length} files");
|
||||
final res = await FileDataService.instance.getFilesData(ids);
|
||||
_logger.info("embeddingResponse ${res.debugLog()}");
|
||||
final List<Face> faces = [];
|
||||
final List<ClipEmbedding> clipEmbeddings = [];
|
||||
for (FileDataEntity fileMl in res.data.values) {
|
||||
final existingInstruction = pendingIndex[fileMl.fileID]!;
|
||||
final facesFromRemoteEmbedding = _getFacesFromRemoteEmbedding(fileMl);
|
||||
//Note: Always do null check, empty value means no face was found.
|
||||
if (facesFromRemoteEmbedding != null) {
|
||||
faces.addAll(facesFromRemoteEmbedding);
|
||||
existingInstruction.shouldRunFaces = false;
|
||||
}
|
||||
if (fileMl.clipEmbedding != null &&
|
||||
fileMl.clipEmbedding!.version >= clipMlVersion) {
|
||||
clipEmbeddings.add(
|
||||
ClipEmbedding(
|
||||
fileID: fileMl.fileID,
|
||||
embedding: fileMl.clipEmbedding!.embedding,
|
||||
version: fileMl.clipEmbedding!.version,
|
||||
),
|
||||
);
|
||||
existingInstruction.shouldRunClip = false;
|
||||
}
|
||||
if (!existingInstruction.pendingML) {
|
||||
pendingIndex.remove(fileMl.fileID);
|
||||
} else {
|
||||
existingInstruction.existingRemoteFileML = fileMl;
|
||||
pendingIndex[fileMl.fileID] = existingInstruction;
|
||||
}
|
||||
}
|
||||
for (final fileID in pendingIndex.keys) {
|
||||
final instruction = pendingIndex[fileID]!;
|
||||
if (instruction.pendingML) {
|
||||
batchToYield.add(instruction);
|
||||
if (batchToYield.length == yieldSize) {
|
||||
_logger.info("queueing indexing for $yieldSize");
|
||||
yield batchToYield;
|
||||
batchToYield = [];
|
||||
}
|
||||
}
|
||||
}
|
||||
await MLDataDB.instance.bulkInsertFaces(faces);
|
||||
await MLDataDB.instance.putMany(clipEmbeddings);
|
||||
}
|
||||
// Yield any remaining instructions
|
||||
if (batchToYield.isNotEmpty) {
|
||||
_logger.info("queueing indexing for ${batchToYield.length}");
|
||||
yield batchToYield;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a list of faces from the given remote fileML. null if the version is less than the current version
|
||||
// or if the remote faceEmbedding is null.
|
||||
List<Face>? _getFacesFromRemoteEmbedding(FileDataEntity fileMl) {
|
||||
final RemoteFaceEmbedding? remoteFaceEmbedding = fileMl.faceEmbedding;
|
||||
if (_shouldDiscardRemoteEmbedding(fileMl)) {
|
||||
return null;
|
||||
}
|
||||
final List<Face> faces = [];
|
||||
if (remoteFaceEmbedding!.faces.isEmpty) {
|
||||
faces.add(
|
||||
Face.empty(fileMl.fileID),
|
||||
);
|
||||
} else {
|
||||
for (final f in remoteFaceEmbedding.faces) {
|
||||
f.fileInfo = FileInfo(
|
||||
imageHeight: remoteFaceEmbedding.height,
|
||||
imageWidth: remoteFaceEmbedding.width,
|
||||
);
|
||||
faces.add(f);
|
||||
}
|
||||
}
|
||||
return faces;
|
||||
}
|
||||
|
||||
bool _shouldDiscardRemoteEmbedding(FileDataEntity fileML) {
|
||||
final fileID = fileML.fileID;
|
||||
final RemoteFaceEmbedding? faceEmbedding = fileML.faceEmbedding;
|
||||
if (faceEmbedding == null || faceEmbedding.version < faceMlVersion) {
|
||||
|
||||
Reference in New Issue
Block a user