Various minor fixes for ML:
- Only download model on Wi-Fi and notify user
- Fix ML queue not stopping on interaction
- Partial fix for ML stopping randomly
- Fix ML queue not stopping on disabling ML
- Update ML debug options
Laurens Priem
2024-08-23 17:56:57 +02:00
committed by GitHub
10 changed files with 290 additions and 199 deletions

View File

@@ -660,6 +660,14 @@ class MLDataDB {
return maps.first['count'] as int;
}
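/// Returns the number of faces stored with a negative score; judging by the
/// query below, a negative score appears to be the sentinel for faces that
/// errored during indexing.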
Future<int> getErroredFaceCount() async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore < 0',
);
return maps.first['count'] as int;
}
Future<int> getClusteredOrFacelessFileCount() async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> clustered = await db.getAll(
@@ -942,7 +950,7 @@ class MLDataDB {
}
/// WARNING: This will delete ALL data in the tables! Only use this for debug/testing purposes!
Future<void> dropFeedbackTables() async {
Future<void> dropFacesFeedbackTables() async {
try {
final db = await instance.asyncDB;

View File

@@ -1,4 +1,9 @@
import "dart:io" show Platform;
const faceMlVersion = 1;
const clipMlVersion = 1;
const clusterMlVersion = 1;
const minimumClusterSize = 2;
const embeddingFetchLimit = 200;
final fileDownloadMlLimit = Platform.isIOS ? 5 : 10;
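
A minimal sketch, assuming a consumer shaped like fetchEmbeddingsAndInstructions further down in this diff, of how these two limits cooperate: embeddings are fetched from remote in batches of embeddingFetchLimit, then handed to the indexing loop in chunks of fileDownloadMlLimit (the smaller iOS chunk is presumably about resource limits; the commit does not say):

import "dart:math" show min;

// Sketch only: fetch embeddings in large remote batches, then yield small
// chunks to the indexing loop. `uploadedFileIDs` is a placeholder input.
Stream<List<int>> batchedIds(List<int> uploadedFileIDs) async* {
  for (var i = 0; i < uploadedFileIDs.length; i += embeddingFetchLimit) {
    final fetchBatch = uploadedFileIDs.sublist(
      i,
      min(i + embeddingFetchLimit, uploadedFileIDs.length),
    );
    for (var j = 0; j < fetchBatch.length; j += fileDownloadMlLimit) {
      yield fetchBatch.sublist(
        j,
        min(j + fileDownloadMlLimit, fetchBatch.length),
      );
    }
  }
}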

View File

@@ -5,17 +5,8 @@ import "dart:ui" show Image;
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/list.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/filedata/filedata_service.dart";
import "package:photos/services/filedata/model/file_data.dart";
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart";
import "package:photos/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart";
@@ -23,7 +14,6 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d
import "package:photos/services/machine_learning/ml_exceptions.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/ml_util.dart";
class FaceRecognitionService {
final _logger = Logger("FaceRecognitionService");
@@ -40,8 +30,6 @@ class FaceRecognitionService {
bool _shouldSyncPeople = false;
bool _isSyncing = false;
static const _embeddingFetchLimit = 200;
Future<void> init() async {
if (_isInitialized) {
return;
@@ -65,12 +53,6 @@ class FaceRecognitionService {
Future<void> sync() async {
await _syncPersonFeedback();
if (localSettings.remoteFetchEnabled) {
} else {
_logger.severe(
'Not fetching embeddings because user manually disabled it in debug options',
);
}
}
Future<void> _syncPersonFeedback() async {
@@ -86,97 +68,6 @@ class FaceRecognitionService {
_isSyncing = false;
}
Stream<List<FileMLInstruction>> syncEmbeddings({
int yieldSize = 10,
}) async* {
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
final List<List<FileMLInstruction>> chunks =
filesToIndex.chunks(_embeddingFetchLimit);
List<FileMLInstruction> batchToYield = [];
for (final chunk in chunks) {
final Set<int> ids = {};
final Map<int, FileMLInstruction> pendingIndex = {};
for (final instruction in chunk) {
ids.add(instruction.file.uploadedFileID!);
pendingIndex[instruction.file.uploadedFileID!] = instruction;
}
_logger.info("fetching embeddings for ${ids.length} files");
final res = await FileDataService.instance.getFilesData(ids);
_logger.info("embeddingResponse ${res.debugLog()}");
final List<Face> faces = [];
final List<ClipEmbedding> clipEmbeddings = [];
for (FileDataEntity fileMl in res.data.values) {
final existingInstruction = pendingIndex[fileMl.fileID]!;
final facesFromRemoteEmbedding = _getFacesFromRemoteEmbedding(fileMl);
// Note: always do a null check; an empty list means no face was found.
if (facesFromRemoteEmbedding != null) {
faces.addAll(facesFromRemoteEmbedding);
existingInstruction.shouldRunFaces = false;
}
if (fileMl.clipEmbedding != null &&
fileMl.clipEmbedding!.version >= clipMlVersion) {
clipEmbeddings.add(
ClipEmbedding(
fileID: fileMl.fileID,
embedding: fileMl.clipEmbedding!.embedding,
version: fileMl.clipEmbedding!.version,
),
);
existingInstruction.shouldRunClip = false;
}
if (!existingInstruction.pendingML) {
pendingIndex.remove(fileMl.fileID);
} else {
existingInstruction.existingRemoteFileML = fileMl;
pendingIndex[fileMl.fileID] = existingInstruction;
}
}
for (final fileID in pendingIndex.keys) {
final instruction = pendingIndex[fileID]!;
if (instruction.pendingML) {
batchToYield.add(instruction);
if (batchToYield.length == yieldSize) {
_logger.info("queueing indexing for $yieldSize");
yield batchToYield;
batchToYield = [];
}
}
}
await MLDataDB.instance.bulkInsertFaces(faces);
await MLDataDB.instance.putMany(clipEmbeddings);
}
// Yield any remaining instructions
if (batchToYield.isNotEmpty) {
_logger.info("queueing indexing for ${batchToYield.length}");
yield batchToYield;
}
}
// Returns a list of faces from the given remote fileML, or null if the version
// is less than the current version or if the remote faceEmbedding is null.
List<Face>? _getFacesFromRemoteEmbedding(FileDataEntity fileMl) {
final RemoteFaceEmbedding? remoteFaceEmbedding = fileMl.faceEmbedding;
if (shouldDiscardRemoteEmbedding(fileMl)) {
return null;
}
final List<Face> faces = [];
if (remoteFaceEmbedding!.faces.isEmpty) {
faces.add(
Face.empty(fileMl.fileID),
);
} else {
for (final f in remoteFaceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: remoteFaceEmbedding.height,
imageWidth: remoteFaceEmbedding.width,
);
faces.add(f);
}
}
return faces;
}
static Future<List<FaceResult>> runFacesPipeline(
int enteFileID,
Image image,

View File

@@ -158,8 +158,8 @@ class MLComputer {
}
Future<List<double>> runClipText(String query) async {
await _ensureLoadedClipTextModel();
try {
await _ensureLoadedClipTextModel();
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
final textEmbedding = await _runInIsolate(
(
@@ -195,9 +195,10 @@ class MLComputer {
// Load ClipText model
final String modelName = ClipTextEncoder.instance.modelName;
final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath;
final String modelPath =
await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
final String? modelPath = await ClipTextEncoder.instance.downloadModelSafe();
if (modelPath == null) {
throw Exception("Could not download clip text model, no wifi");
}
final address = await _runInIsolate(
(
MLComputerOperation.loadModel,

View File

@@ -5,6 +5,7 @@ import "package:onnx_dart/onnx_dart.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/machine_learning/onnx_env.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/network_util.dart";
import "package:synchronized/synchronized.dart";
abstract class MlModel {
@@ -32,6 +33,7 @@ abstract class MlModel {
bool _isNativePluginInitialized = false;
int _nativePluginSessionIndex = -1;
/// WARNING: If [downloadModel] has not been called first, this method will download the model, potentially using high bandwidth.
Future<(String, String)> getModelNameAndPath() async {
return _downloadModelLock.synchronized(() async {
final path =
@@ -40,12 +42,29 @@ abstract class MlModel {
});
}
Future<void> downloadModel([bool forceRefresh = false]) async {
Future<String?> downloadModelSafe() async {
if (await RemoteAssetsService.instance.hasAsset(modelRemotePath)) {
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
} else {
if (await canUseHighBandwidth()) {
return await downloadModel();
} else {
logger.warning(
'Cannot return model path as it is not available locally and high bandwidth is not available.',
);
return null;
}
}
}
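
MLComputer (earlier in this diff) already follows this pattern; as a generic sketch, a caller treats a null result as "model unavailable right now" instead of forcing a download over mobile data (ensureModelLoaded is a hypothetical helper, not part of the commit):

// Hypothetical caller: a null path means the model is neither cached nor
// downloadable right now (no high-bandwidth connection); retry later.
Future<void> ensureModelLoaded(MlModel model) async {
  final String? modelPath = await model.downloadModelSafe();
  if (modelPath == null) {
    throw Exception("Model not available locally and no Wi-Fi to download");
  }
  // ... create the inference session from modelPath ...
}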
Future<String> downloadModel([bool forceRefresh = false]) async {
return _downloadModelLock.synchronized(() async {
if (forceRefresh) {
await RemoteAssetsService.instance.getAssetIfUpdated(modelRemotePath);
final file = await RemoteAssetsService.instance
.getAssetIfUpdated(modelRemotePath);
return file!.path;
} else {
await RemoteAssetsService.instance.getAsset(modelRemotePath);
return await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
}
});
}

View File

@@ -55,9 +55,9 @@ class MLService {
bool _showClusteringIsHappening = false;
bool _mlControllerStatus = false;
bool _isIndexingOrClusteringRunning = false;
bool _isRunningML = false;
bool _shouldPauseIndexingAndClustering = false;
static final int _fileDownloadLimit = Platform.isIOS ? 5 : 10;
static const _kForceClusteringFaceCount = 8000;
/// Only call this function once at app startup, after that you can directly call [runAllML]
@@ -118,6 +118,7 @@ class MLService {
_mlControllerStatus = true;
}
if (_cannotRunMLFunction() && !force) return;
_isRunningML = true;
await sync();
@@ -134,6 +135,16 @@ class MLService {
} catch (e, s) {
_logger.severe("runAllML failed", e, s);
rethrow;
} finally {
_isRunningML = false;
}
}
void triggerML() {
if (_mlControllerStatus &&
!_isIndexingOrClusteringRunning &&
!_isRunningML) {
unawaited(runAllML());
}
}
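
Together with the new _isRunningML flag, this is the "partial fix for ML stopping randomly" from the commit message: the settings page (further down in this diff) pokes triggerML from a periodic timer, and the guards make the poke a no-op while a run is already in flight. The wiring, as added later in this commit:

// From MLStatusWidgetState below: a 10-second timer re-triggers ML in case
// an earlier run stopped unexpectedly.
_timer = Timer.periodic(const Duration(seconds: 10), (timer) {
  MLService.instance.triggerML();
  setState(() {});
});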
@@ -159,24 +170,26 @@ class MLService {
_isIndexingOrClusteringRunning = true;
_logger.info('starting image indexing');
final Stream<List<FileMLInstruction>> instructionStream =
FaceRecognitionService.instance
.syncEmbeddings(yieldSize: _fileDownloadLimit);
fetchEmbeddingsAndInstructions(fileDownloadMlLimit);
int fileAnalyzedCount = 0;
final Stopwatch stopwatch = Stopwatch()..start();
stream:
await for (final chunk in instructionStream) {
if (!await canUseHighBandwidth()) {
_logger.info(
'stopping indexing because user is not connected to wifi',
);
break;
break stream;
} else {
await _ensureDownloadedModels();
}
final futures = <Future<bool>>[];
for (final instruction in chunk) {
if (_shouldPauseIndexingAndClustering) {
_logger.info("indexAllImages() was paused, stopping");
break;
break stream;
}
await _ensureLoadedModels(instruction);
futures.add(processImage(instruction));
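
The `stream:` label above is a Dart labeled statement: `break stream;` exits the labeled `await for` even from inside the nested `for`, which is what makes a pause request stop the whole indexing run rather than just the current chunk. A self-contained illustration (all names are placeholders):

// Labeled break: `break outer;` leaves both loops; a plain `break` would
// only leave the inner one.
Future<void> drain(Stream<List<int>> chunks, bool Function() shouldStop) async {
  outer:
  await for (final chunk in chunks) {
    for (final item in chunk) {
      if (shouldStop()) break outer;
      print(item);
    }
  }
}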
@@ -491,6 +504,13 @@ class MLService {
}
}
void triggerModelsDownload() {
if (!areModelsDownloaded && !_downloadModelLock.locked) {
_logger.info("Models not downloaded, starting download");
unawaited(_ensureDownloadedModels());
}
}
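
The machine learning settings page (below) calls this from its status row whenever canUseHighBandwidth() resolves to true, so a user sitting on the "waiting for Wi-Fi" state gets the download kicked off automatically once connectivity improves:

// Condensed from the settings page below: on Wi-Fi, start the download and
// show "loading model"; otherwise show "waiting for Wi-Fi".
if (await canUseHighBandwidth()) {
  MLService.instance.triggerModelsDownload();
}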
Future<void> _ensureDownloadedModels([bool forceRefresh = false]) async {
if (_downloadModelLock.locked) {
_logger.finest("Download models already in progress");
@@ -500,6 +520,13 @@ class MLService {
_logger.finest("Models already downloaded");
return;
}
final goodInternet = await canUseHighBandwidth();
if (!goodInternet) {
_logger.info(
"Cannot download models because user is not connected to wifi",
);
return;
}
_logger.info('Downloading models');
await Future.wait([
FaceDetectionService.instance.downloadModel(forceRefresh),
@@ -534,7 +561,7 @@ class MLService {
}
bool _cannotRunMLFunction({String function = ""}) {
if (kDebugMode && Platform.isIOS) {
if (kDebugMode && Platform.isIOS && !_isIndexingOrClusteringRunning) {
return false;
}
if (_isIndexingOrClusteringRunning) {

View File

@@ -7,7 +7,7 @@ import "package:photos/db/ml/db.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/models/ml/face/person.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/services/machine_learning/ml_service.dart';
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
@@ -17,10 +17,11 @@ import 'package:photos/ui/components/expandable_menu_item_widget.dart';
import 'package:photos/ui/components/menu_item_widget/menu_item_widget.dart';
import 'package:photos/ui/settings/common_settings.dart';
import "package:photos/utils/dialog_util.dart";
import "package:photos/utils/ml_util.dart";
import 'package:photos/utils/toast_util.dart';
class MLDebugSectionWidget extends StatefulWidget {
const MLDebugSectionWidget({Key? key}) : super(key: key);
const MLDebugSectionWidget({super.key});
@override
State<MLDebugSectionWidget> createState() => _MLDebugSectionWidgetState();
@@ -54,18 +55,19 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
}
Widget _getSectionOptions(BuildContext context) {
final Logger _logger = Logger("MLDebugSectionWidget");
final Logger logger = Logger("MLDebugSectionWidget");
return Column(
children: [
MenuItemWidget(
captionedTextWidget: FutureBuilder<int>(
future: MLDataDB.instance.getFaceIndexedFileCount(),
captionedTextWidget: FutureBuilder<IndexStatus>(
future: getIndexStatus(),
builder: (context, snapshot) {
if (snapshot.hasData) {
final IndexStatus status = snapshot.data!;
return CaptionedTextWidget(
title: localSettings.isMLIndexingEnabled
? "Disable faces (${snapshot.data!} files done)"
: "Enable faces (${snapshot.data!} files done)",
? "Disable ML (${status.indexedItems} files indexed)"
: "Enable ML (${status.indexedItems} files indexed)",
);
}
return const SizedBox.shrink();
@@ -77,14 +79,18 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
onTap: () async {
try {
final isEnabled = await localSettings.toggleMLIndexing();
if (!isEnabled) {
if (isEnabled) {
await MLService.instance.init();
await SemanticSearchService.instance.init();
unawaited(MLService.instance.runAllML(force: true));
} else {
MLService.instance.pauseIndexingAndClustering();
}
if (mounted) {
setState(() {});
}
} catch (e, s) {
_logger.warning('indexing failed ', e, s);
logger.warning('indexing failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -93,8 +99,8 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: localSettings.remoteFetchEnabled
? "Remote fetch enabled"
: "Remote fetch disabled",
? "Disable remote fetch"
: "Enable remote fetch",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -106,7 +112,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
setState(() {});
}
} catch (e, s) {
_logger.warning('Remote fetch toggle failed ', e, s);
logger.warning('Remote fetch toggle failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -115,8 +121,8 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: MLService.instance.debugIndexingDisabled
? "Debug enable indexing again"
: "Debug disable indexing",
? "Enable auto indexing (debug)"
: "Disable auto indexing (debug)",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -127,12 +133,14 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
!MLService.instance.debugIndexingDisabled;
if (MLService.instance.debugIndexingDisabled) {
MLService.instance.pauseIndexingAndClustering();
} else {
unawaited(MLService.instance.runAllML());
}
if (mounted) {
setState(() {});
}
} catch (e, s) {
_logger.warning('debugIndexingDisabled toggle failed ', e, s);
logger.warning('debugIndexingDisabled toggle failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -140,7 +148,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Run sync, indexing, clustering",
title: "Trigger run ML",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -150,7 +158,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
MLService.instance.debugIndexingDisabled = false;
unawaited(MLService.instance.runAllML());
} catch (e, s) {
_logger.warning('indexAndClusterAll failed ', e, s);
logger.warning('indexAndClusterAll failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -158,7 +166,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Run indexing",
title: "Trigger run indexing",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -168,7 +176,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
MLService.instance.debugIndexingDisabled = false;
unawaited(MLService.instance.indexAllImages());
} catch (e, s) {
_logger.warning('indexing failed ', e, s);
logger.warning('indexing failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -181,7 +189,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
if (snapshot.hasData) {
return CaptionedTextWidget(
title:
"Run clustering (${(100 * snapshot.data!).toStringAsFixed(0)}% done)",
"Trigger clustering (${(100 * snapshot.data!).toStringAsFixed(0)}% done)",
);
}
return const SizedBox.shrink();
@@ -194,41 +202,41 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
try {
await PersonService.instance.fetchRemoteClusterFeedback();
MLService.instance.debugIndexingDisabled = false;
await MLService.instance.clusterAllImages(clusterInBuckets: true);
await MLService.instance.clusterAllImages();
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('clustering failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Check for mixed clusters",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
try {
final susClusters =
await ClusterFeedbackService.instance.checkForMixedClusters();
for (final clusterinfo in susClusters) {
Future.delayed(const Duration(seconds: 4), () {
showToast(
context,
'Cluster with ${clusterinfo.$2} photos is sus',
);
});
}
} catch (e, s) {
_logger.warning('Checking for mixed clusters failed', e, s);
logger.warning('clustering failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
// sectionOptionSpacing,
// MenuItemWidget(
// captionedTextWidget: const CaptionedTextWidget(
// title: "Check for mixed clusters",
// ),
// pressedColor: getEnteColorScheme(context).fillFaint,
// trailingIcon: Icons.chevron_right_outlined,
// trailingIconIsMuted: true,
// onTap: () async {
// try {
// final susClusters =
// await ClusterFeedbackService.instance.checkForMixedClusters();
// for (final clusterinfo in susClusters) {
// Future.delayed(const Duration(seconds: 4), () {
// showToast(
// context,
// 'Cluster with ${clusterinfo.$2} photos is sus',
// );
// });
// }
// } catch (e, s) {
// logger.warning('Checking for mixed clusters failed', e, s);
// await showGenericErrorDialog(context: context, error: e);
// }
// },
// ),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
@@ -239,16 +247,28 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
trailingIconIsMuted: true,
onTap: () async {
try {
await PersonService.instance.reconcileClusters();
Bus.instance.fire(PeopleChangedEvent());
await FaceRecognitionService.instance.sync();
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('sync person mappings failed ', e, s);
logger.warning('sync person mappings failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Show empty indexes",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
final emptyFaces = await MLDataDB.instance.getErroredFaceCount();
showShortToast(context, '$emptyFaces empty faces');
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Reset faces feedback",
@@ -262,15 +282,15 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
context,
title: "Are you sure?",
body:
"This will drop all people and their related feedback. It will keep clustering labels and embeddings untouched.",
"This will drop all people and their related feedback stored locally. It will keep clustering labels and embeddings untouched, as well as persons stored on remote.",
firstButtonLabel: "Yes, confirm",
firstButtonOnTap: () async {
try {
await MLDataDB.instance.dropFeedbackTables();
await MLDataDB.instance.dropFacesFeedbackTables();
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('reset feedback failed ', e, s);
logger.warning('reset feedback failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -290,7 +310,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
context,
title: "Are you sure?",
body:
"This will delete all people, their related feedback and clustering labels. It will keep embeddings untouched.",
"This will delete all people (also from remote), their related feedback and clustering labels. It will keep embeddings untouched.",
firstButtonLabel: "Yes, confirm",
firstButtonOnTap: () async {
try {
@@ -303,7 +323,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('peopleToPersonMapping remove failed ', e, s);
logger.warning('peopleToPersonMapping remove failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
@@ -313,7 +333,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Reset faces everything (embeddings)",
title: "Reset all local faces",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -323,7 +343,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
context,
title: "Are you sure?",
body:
"You will need to again re-index all the faces. You can drop feedback if you want to label again",
"This will drop all local faces data. You will need to again re-index faces.",
firstButtonLabel: "Yes, confirm",
firstButtonOnTap: () async {
try {
@@ -332,16 +352,17 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('drop feedback failed ', e, s);
logger.warning('drop feedback failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
);
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Reset clip embeddings",
title: "Reset all local clip",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
@@ -358,7 +379,7 @@ class _MLDebugSectionWidgetState extends State<MLDebugSectionWidget> {
await SemanticSearchService.instance.clearIndexes();
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('drop clip embeddings failed ', e, s);
logger.warning('drop clip embeddings failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},

View File

@@ -27,6 +27,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
import "package:photos/ui/components/toggle_switch_widget.dart";
import "package:photos/ui/settings/ml/enable_ml_consent.dart";
import "package:photos/utils/ml_util.dart";
import "package:photos/utils/network_util.dart";
import "package:photos/utils/wakelock_util.dart";
class MachineLearningSettingsPage extends StatefulWidget {
@@ -203,6 +204,7 @@ class _MachineLearningSettingsPageState
await SemanticSearchService.instance.init();
unawaited(MLService.instance.runAllML(force: true));
} else {
MLService.instance.pauseIndexingAndClustering();
await UserRemoteFlagService.instance
.setBoolValue(UserRemoteFlagService.mlEnabled, false);
}
@@ -257,6 +259,8 @@ class ModelLoadingState extends StatefulWidget {
class _ModelLoadingStateState extends State<ModelLoadingState> {
StreamSubscription<(String, int, int)>? _progressStream;
final Map<String, (int, int)> _progressMap = {};
Timer? _timer;
@override
void initState() {
_progressStream =
@@ -277,6 +281,9 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
setState(() {});
}
});
_timer = Timer.periodic(const Duration(seconds: 10), (timer) {
setState(() {});
});
super.initState();
}
@@ -284,6 +291,7 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
void dispose() {
super.dispose();
_progressStream?.cancel();
_timer?.cancel();
}
@override
@@ -292,8 +300,25 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
children: [
MenuSectionTitle(title: S.of(context).status),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: _getTitle(context),
captionedTextWidget: FutureBuilder(
future: canUseHighBandwidth(),
builder: (context, snapshot) {
if (snapshot.hasData) {
if (snapshot.data!) {
MLService.instance.triggerModelsDownload();
return CaptionedTextWidget(
title: S.of(context).loadingModel,
key: const ValueKey("loading_model"),
);
} else {
return CaptionedTextWidget(
title: S.of(context).waitingForWifi,
key: const ValueKey("waiting_for_wifi"),
);
}
}
return const CaptionedTextWidget(title: "");
},
),
trailingWidget: EnteLoadingWidget(
size: 12,
@@ -322,15 +347,6 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
],
);
}
String _getTitle(BuildContext context) {
// TODO: uncomment code below to actually check for high bandwidth
// final usableConnection = await canUseHighBandwidth();
// if (!usableConnection) {
// return S.of(context).waitingForWifi;
// }
return S.of(context).loadingModel;
}
}
class MLStatusWidget extends StatefulWidget {
@@ -348,6 +364,7 @@ class MLStatusWidgetState extends State<MLStatusWidget> {
void initState() {
super.initState();
_timer = Timer.periodic(const Duration(seconds: 10), (timer) {
MLService.instance.triggerML();
setState(() {});
});
}

View File

@@ -144,9 +144,7 @@ class SettingsPage extends StatelessWidget {
if (hasLoggedIn && flagService.internalUser) {
contents.addAll([sectionSpacing, const DebugSectionWidget()]);
if (flagService.isBetaUser) {
contents.addAll([sectionSpacing, const MLDebugSectionWidget()]);
}
contents.addAll([sectionSpacing, const MLDebugSectionWidget()]);
}
contents.add(const AppVersionWidget());
contents.add(const DeveloperSettingsWidget());

View File

@@ -7,11 +7,16 @@ import "package:photos/core/configuration.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/extensions/list.dart";
import "package:photos/models/file/extensions/file_props.dart";
import "package:photos/models/file/file.dart";
import "package:photos/models/file/file_type.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/dimension.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/filedata/filedata_service.dart";
import "package:photos/services/filedata/model/file_data.dart";
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
import "package:photos/services/machine_learning/ml_exceptions.dart";
@@ -68,6 +73,7 @@ Future<IndexStatus> getIndexStatus() async {
}
}
/// Returns a list of file instructions for files that should be indexed for ML.
Future<List<FileMLInstruction>> getFilesForMlIndexing() async {
_logger.info('getFilesForMlIndexing called');
final time = DateTime.now();
@@ -146,7 +152,105 @@ Future<List<FileMLInstruction>> getFilesForMlIndexing() async {
return sortedBylocalID;
}
bool shouldDiscardRemoteEmbedding(FileDataEntity fileML) {
Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
int yieldSize,
) async* {
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
final List<List<FileMLInstruction>> chunks =
filesToIndex.chunks(embeddingFetchLimit);
List<FileMLInstruction> batchToYield = [];
for (final chunk in chunks) {
if (!localSettings.remoteFetchEnabled) {
_logger.warning("remoteFetchEnabled is false, skiping embedding fetch");
final batches = chunk.chunks(yieldSize);
for (final batch in batches) {
yield batch;
}
continue;
}
final Set<int> ids = {};
final Map<int, FileMLInstruction> pendingIndex = {};
for (final instruction in chunk) {
ids.add(instruction.file.uploadedFileID!);
pendingIndex[instruction.file.uploadedFileID!] = instruction;
}
_logger.info("fetching embeddings for ${ids.length} files");
final res = await FileDataService.instance.getFilesData(ids);
_logger.info("embeddingResponse ${res.debugLog()}");
final List<Face> faces = [];
final List<ClipEmbedding> clipEmbeddings = [];
for (FileDataEntity fileMl in res.data.values) {
final existingInstruction = pendingIndex[fileMl.fileID]!;
final facesFromRemoteEmbedding = _getFacesFromRemoteEmbedding(fileMl);
// Note: always do a null check; an empty list means no face was found.
if (facesFromRemoteEmbedding != null) {
faces.addAll(facesFromRemoteEmbedding);
existingInstruction.shouldRunFaces = false;
}
if (fileMl.clipEmbedding != null &&
fileMl.clipEmbedding!.version >= clipMlVersion) {
clipEmbeddings.add(
ClipEmbedding(
fileID: fileMl.fileID,
embedding: fileMl.clipEmbedding!.embedding,
version: fileMl.clipEmbedding!.version,
),
);
existingInstruction.shouldRunClip = false;
}
if (!existingInstruction.pendingML) {
pendingIndex.remove(fileMl.fileID);
} else {
existingInstruction.existingRemoteFileML = fileMl;
pendingIndex[fileMl.fileID] = existingInstruction;
}
}
for (final fileID in pendingIndex.keys) {
final instruction = pendingIndex[fileID]!;
if (instruction.pendingML) {
batchToYield.add(instruction);
if (batchToYield.length == yieldSize) {
_logger.info("queueing indexing for $yieldSize");
yield batchToYield;
batchToYield = [];
}
}
}
await MLDataDB.instance.bulkInsertFaces(faces);
await MLDataDB.instance.putMany(clipEmbeddings);
}
// Yield any remaining instructions
if (batchToYield.isNotEmpty) {
_logger.info("queueing indexing for ${batchToYield.length}");
yield batchToYield;
}
}
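
indexAllImages in ml_service.dart (earlier in this diff) is the consumer of this stream; the consumption pattern inside that method, sketched with processBatch as a placeholder for the per-file ML pipeline:

// Sketch: drain the instruction stream chunk by chunk, stopping early when
// high-bandwidth connectivity is lost.
await for (final chunk in fetchEmbeddingsAndInstructions(fileDownloadMlLimit)) {
  if (!await canUseHighBandwidth()) break;
  await processBatch(chunk); // placeholder
}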
// Returns a list of faces from the given remote fileML, or null if the version
// is less than the current version or if the remote faceEmbedding is null.
List<Face>? _getFacesFromRemoteEmbedding(FileDataEntity fileMl) {
final RemoteFaceEmbedding? remoteFaceEmbedding = fileMl.faceEmbedding;
if (_shouldDiscardRemoteEmbedding(fileMl)) {
return null;
}
final List<Face> faces = [];
if (remoteFaceEmbedding!.faces.isEmpty) {
faces.add(
Face.empty(fileMl.fileID),
);
} else {
for (final f in remoteFaceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: remoteFaceEmbedding.height,
imageWidth: remoteFaceEmbedding.width,
);
faces.add(f);
}
}
return faces;
}
bool _shouldDiscardRemoteEmbedding(FileDataEntity fileML) {
final fileID = fileML.fileID;
final RemoteFaceEmbedding? faceEmbedding = fileML.faceEmbedding;
if (faceEmbedding == null || faceEmbedding.version < faceMlVersion) {