diff --git a/auth/lib/app/view/app.dart b/auth/lib/app/view/app.dart index 3bd9d4e734..5053f70f12 100644 --- a/auth/lib/app/view/app.dart +++ b/auth/lib/app/view/app.dart @@ -189,7 +189,7 @@ class _AppState extends State with WindowListener, TrayListener { windowManager.show(); break; case 'exit_app': - windowManager.close(); + windowManager.destroy(); break; } } diff --git a/auth/lib/models/code.dart b/auth/lib/models/code.dart index 7a7daf58dd..696d3f2fc1 100644 --- a/auth/lib/models/code.dart +++ b/auth/lib/models/code.dart @@ -128,7 +128,7 @@ class Code { final code = Code( _getAccount(uri), issuer, - _getDigits(uri, issuer), + _getDigits(uri), _getPeriod(uri), getSanitizedSecret(uri.queryParameters['secret']!), _getAlgorithm(uri), @@ -201,11 +201,11 @@ class Code { } } - static int _getDigits(Uri uri, String issuer) { + static int _getDigits(Uri uri) { try { return int.parse(uri.queryParameters['digits']!); } catch (e) { - if (issuer.toLowerCase() == "steam") { + if (uri.host == "steam") { return steamDigits; } return defaultDigits; diff --git a/auth/lib/onboarding/view/setup_enter_secret_key_page.dart b/auth/lib/onboarding/view/setup_enter_secret_key_page.dart index 6741788c30..b4ab8bfd0f 100644 --- a/auth/lib/onboarding/view/setup_enter_secret_key_page.dart +++ b/auth/lib/onboarding/view/setup_enter_secret_key_page.dart @@ -240,7 +240,7 @@ class _SetupEnterSecretKeyPageState extends State { final account = _accountController.text.trim(); final issuer = _issuerController.text.trim(); final secret = _secretController.text.trim().replaceAll(' ', ''); - final isStreamCode = issuer.toLowerCase() == "steam"; + final isStreamCode = issuer.toLowerCase() == "steam" || issuer.toLowerCase().contains('steampowered.com'); if (widget.code != null && widget.code!.secret != secret) { ButtonResult? result = await showChoiceActionSheet( context, diff --git a/auth/lib/ui/code_widget.dart b/auth/lib/ui/code_widget.dart index a690d977e3..cb8b274cad 100644 --- a/auth/lib/ui/code_widget.dart +++ b/auth/lib/ui/code_widget.dart @@ -48,7 +48,6 @@ class _CodeWidgetState extends State { late bool _shouldShowLargeIcon; late bool _hideCode; bool isMaskingEnabled = false; - late final colorScheme = getEnteColorScheme(context); @override void initState() { @@ -78,6 +77,7 @@ class _CodeWidgetState extends State { @override Widget build(BuildContext context) { + final colorScheme = getEnteColorScheme(context); if (isMaskingEnabled != PreferenceService.instance.shouldHideCodes()) { isMaskingEnabled = PreferenceService.instance.shouldHideCodes(); _hideCode = isMaskingEnabled; @@ -91,6 +91,100 @@ class _CodeWidgetState extends State { _isInitialized = true; } final l10n = context.l10n; + + Widget getCardContents(AppLocalizations l10n) { + return Stack( + children: [ + if (widget.code.isPinned) + Align( + alignment: Alignment.topRight, + child: CustomPaint( + painter: PinBgPainter( + color: colorScheme.pinnedBgColor, + ), + size: const Size(39, 39), + ), + ), + Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + if (widget.code.type.isTOTPCompatible) + CodeTimerProgress( + period: widget.code.period, + ), + const SizedBox(height: 16), + Row( + children: [ + _shouldShowLargeIcon ? _getIcon() : const SizedBox.shrink(), + Expanded( + child: Column( + children: [ + _getTopRow(), + const SizedBox(height: 4), + _getBottomRow(l10n), + ], + ), + ), + ], + ), + const SizedBox( + height: 20, + ), + ], + ), + if (widget.code.isPinned) ...[ + Align( + alignment: Alignment.topRight, + child: Padding( + padding: const EdgeInsets.only(right: 6, top: 6), + child: SvgPicture.asset("assets/svg/pin-card.svg"), + ), + ), + ], + ], + ); + } + + Widget clippedCard(AppLocalizations l10n) { + return Container( + height: 132, + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(8), + color: Theme.of(context).colorScheme.codeCardBackgroundColor, + boxShadow: + widget.code.isPinned ? colorScheme.pinnedCardBoxShadow : [], + ), + child: ClipRRect( + borderRadius: BorderRadius.circular(8), + child: Material( + color: Colors.transparent, + child: InkWell( + customBorder: RoundedRectangleBorder( + borderRadius: BorderRadius.circular(10), + ), + onTap: () { + _copyCurrentOTPToClipboard(); + }, + onDoubleTap: isMaskingEnabled + ? () { + setState( + () { + _hideCode = !_hideCode; + }, + ); + } + : null, + onLongPress: () { + _copyCurrentOTPToClipboard(); + }, + child: getCardContents(l10n), + ), + ), + ), + ); + } + return Container( margin: const EdgeInsets.only(left: 16, right: 16, bottom: 8, top: 8), child: Builder( @@ -126,7 +220,7 @@ class _CodeWidgetState extends State { ], padding: const EdgeInsets.all(8.0), ), - child: _clippedCard(l10n), + child: clippedCard(l10n), ); } @@ -216,7 +310,7 @@ class _CodeWidgetState extends State { ], ), child: Builder( - builder: (context) => _clippedCard(l10n), + builder: (context) => clippedCard(l10n), ), ); }, @@ -224,98 +318,6 @@ class _CodeWidgetState extends State { ); } - Widget _clippedCard(AppLocalizations l10n) { - return Container( - height: 132, - decoration: BoxDecoration( - borderRadius: BorderRadius.circular(8), - color: Theme.of(context).colorScheme.codeCardBackgroundColor, - boxShadow: widget.code.isPinned ? colorScheme.pinnedCardBoxShadow : [], - ), - child: ClipRRect( - borderRadius: BorderRadius.circular(8), - child: Material( - color: Colors.transparent, - child: InkWell( - customBorder: RoundedRectangleBorder( - borderRadius: BorderRadius.circular(10), - ), - onTap: () { - _copyCurrentOTPToClipboard(); - }, - onDoubleTap: isMaskingEnabled - ? () { - setState( - () { - _hideCode = !_hideCode; - }, - ); - } - : null, - onLongPress: () { - _copyCurrentOTPToClipboard(); - }, - child: _getCardContents(l10n), - ), - ), - ), - ); - } - - Widget _getCardContents(AppLocalizations l10n) { - return Stack( - children: [ - if (widget.code.isPinned) - Align( - alignment: Alignment.topRight, - child: CustomPaint( - painter: PinBgPainter( - color: colorScheme.pinnedBgColor, - ), - size: const Size(39, 39), - ), - ), - Column( - crossAxisAlignment: CrossAxisAlignment.start, - mainAxisAlignment: MainAxisAlignment.spaceBetween, - children: [ - if (widget.code.type.isTOTPCompatible) - CodeTimerProgress( - period: widget.code.period, - ), - const SizedBox(height: 16), - Row( - children: [ - _shouldShowLargeIcon ? _getIcon() : const SizedBox.shrink(), - Expanded( - child: Column( - children: [ - _getTopRow(), - const SizedBox(height: 4), - _getBottomRow(l10n), - ], - ), - ), - ], - ), - const SizedBox( - height: 20, - ), - ], - ), - if (widget.code.isPinned) ...[ - Align( - alignment: Alignment.topRight, - child: Padding( - padding: const EdgeInsets.only(right: 6, top: 6), - child: SvgPicture.asset("assets/svg/pin-card.svg"), - ), - ), - ], - ], - ); - } - Widget _getBottomRow(AppLocalizations l10n) { return Container( padding: const EdgeInsets.only(left: 16, right: 16), @@ -585,7 +587,7 @@ class _CodeWidgetState extends State { String _getFormattedCode(String code) { if (_hideCode) { // replace all digits with • - code = code.replaceAll(RegExp(r'\d'), '•'); + code = code.replaceAll(RegExp(r'\S'), '•'); } if (code.length == 6) { return "${code.substring(0, 3)} ${code.substring(3, 6)}"; diff --git a/auth/lib/utils/totp_util.dart b/auth/lib/utils/totp_util.dart index d0076ed41f..61c7f20e92 100644 --- a/auth/lib/utils/totp_util.dart +++ b/auth/lib/utils/totp_util.dart @@ -4,7 +4,7 @@ import 'package:otp/otp.dart' as otp; import 'package:steam_totp/steam_totp.dart'; String getOTP(Code code) { - if (code.issuer.toLowerCase() == 'steam') { + if (code.type == Type.steam) { return _getSteamCode(code); } if (code.type == Type.hotp) { @@ -39,7 +39,7 @@ String _getSteamCode(Code code, [bool isNext = false]) { } String getNextTotp(Code code) { - if (code.issuer.toLowerCase() == 'steam') { + if (code.type == Type.steam) { return _getSteamCode(code, true); } return otp.OTP.generateTOTPCodeString( diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index 6a42a0a3be..50de0b9a11 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -242,8 +242,6 @@ Future _init(bool isBackground, {String via = ''}) async { // unawaited(ObjectDetectionService.instance.init()); if (flagService.faceSearchEnabled) { unawaited(FaceMlService.instance.init()); - FaceMlService.instance.listenIndexOnDiffSync(); - FaceMlService.instance.listenOnPeopleChangedSync(); } else { if (LocalSettings.instance.isFaceIndexingEnabled) { unawaited(LocalSettings.instance.toggleFaceIndexing()); diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 9e72f4c55a..165a695ed4 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -9,7 +9,6 @@ import "dart:ui" show Image; import "package:computer/computer.dart"; import "package:dart_ui_isolate/dart_ui_isolate.dart"; import "package:flutter/foundation.dart" show debugPrint, kDebugMode; -import "package:flutter_image_compress/flutter_image_compress.dart"; import "package:logging/logging.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:package_info_plus/package_info_plus.dart"; @@ -74,7 +73,7 @@ class FaceMlService { late ReceivePort _receivePort = ReceivePort(); late SendPort _mainSendPort; - bool isIsolateSpawned = false; + bool _isIsolateSpawned = false; // singleton pattern FaceMlService._privateConstructor(); @@ -91,12 +90,14 @@ class FaceMlService { bool isInitialized = false; late String client; - bool canRunMLController = false; - bool isImageIndexRunning = false; - bool isClusteringRunning = false; - bool shouldSyncPeople = false; + bool debugIndexingDisabled = false; + bool _mlControllerStatus = false; + bool _isIndexingOrClusteringRunning = false; + bool _shouldPauseIndexingAndClustering = false; + bool _shouldSyncPeople = false; + bool _isSyncing = false; - final int _fileDownloadLimit = 15; + final int _fileDownloadLimit = 10; final int _embeddingFetchLimit = 200; Future init({bool initializeImageMlIsolate = false}) async { @@ -133,31 +134,36 @@ class FaceMlService { _logger.info("client: $client"); isInitialized = true; - canRunMLController = !Platform.isAndroid || kDebugMode; + _mlControllerStatus = !Platform.isAndroid; /// hooking FaceML into [MachineLearningController] - if (Platform.isAndroid && !kDebugMode) { - Bus.instance.on().listen((event) { - if (LocalSettings.instance.isFaceIndexingEnabled == false) { - return; - } - canRunMLController = event.shouldRun; - if (canRunMLController) { + Bus.instance.on().listen((event) { + if (LocalSettings.instance.isFaceIndexingEnabled == false) { + return; + } + _mlControllerStatus = event.shouldRun; + if (_mlControllerStatus) { + if (_shouldPauseIndexingAndClustering) { + _shouldPauseIndexingAndClustering = false; + _logger.info( + "MLController allowed running ML, faces indexing undoing previous pause", + ); + } else { _logger.info( "MLController allowed running ML, faces indexing starting", ); unawaited(indexAndClusterAll()); - } else { - _logger - .info("MLController stopped running ML, faces indexing paused"); - pauseIndexing(); } - }); - } else { - if (!kDebugMode) { - unawaited(indexAndClusterAll()); + } else { + _logger.info( + "MLController stopped running ML, faces indexing will be paused (unless it's fetching embeddings)", + ); + pauseIndexingAndClustering(); } - } + }); + + _listenIndexOnDiffSync(); + _listenOnPeopleChangedSync(); }); } @@ -165,24 +171,15 @@ class FaceMlService { OrtEnv.instance.init(); } - void listenIndexOnDiffSync() { + void _listenIndexOnDiffSync() { Bus.instance.on().listen((event) async { - if (LocalSettings.instance.isFaceIndexingEnabled == false || kDebugMode) { - return; - } - // [neeraj] intentional delay in starting indexing on diff sync, this gives time for the user - // to disable face-indexing in case it's causing crash. In the future, we - // should have a better way to handle this. - shouldSyncPeople = true; - Future.delayed(const Duration(seconds: 10), () { - unawaited(indexAndClusterAll()); - }); + unawaited(sync()); }); } - void listenOnPeopleChangedSync() { + void _listenOnPeopleChangedSync() { Bus.instance.on().listen((event) { - shouldSyncPeople = true; + _shouldSyncPeople = true; }); } @@ -218,9 +215,9 @@ class FaceMlService { }); } - Future initIsolate() async { + Future _initIsolate() async { return _initLockIsolate.synchronized(() async { - if (isIsolateSpawned) return; + if (_isIsolateSpawned) return; _logger.info("initIsolate called"); _receivePort = ReceivePort(); @@ -231,19 +228,19 @@ class FaceMlService { _receivePort.sendPort, ); _mainSendPort = await _receivePort.first as SendPort; - isIsolateSpawned = true; + _isIsolateSpawned = true; _resetInactivityTimer(); } catch (e) { _logger.severe('Could not spawn isolate', e); - isIsolateSpawned = false; + _isIsolateSpawned = false; } }); } - Future ensureSpawnedIsolate() async { - if (!isIsolateSpawned) { - await initIsolate(); + Future _ensureSpawnedIsolate() async { + if (!_isIsolateSpawned) { + await _initIsolate(); } } @@ -286,11 +283,11 @@ class FaceMlService { Future _runInIsolate( (FaceMlOperation, Map) message, ) async { - await ensureSpawnedIsolate(); + await _ensureSpawnedIsolate(); return _functionLock.synchronized(() async { _resetInactivityTimer(); - if (isImageIndexRunning == false || canRunMLController == false) { + if (_shouldPauseIndexingAndClustering == false) { return null; } @@ -332,35 +329,42 @@ class FaceMlService { _logger.info( 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', ); - disposeIsolate(); + _disposeIsolate(); } }); } - void disposeIsolate() async { - if (!isIsolateSpawned) return; + void _disposeIsolate() async { + if (!_isIsolateSpawned) return; await release(); - isIsolateSpawned = false; + _isIsolateSpawned = false; _isolate.kill(); _receivePort.close(); _inactivityTimer?.cancel(); } - Future indexAndClusterAll() async { - if (isClusteringRunning || isImageIndexRunning) { - _logger.info("indexing or clustering is already running, skipping"); + Future sync({bool forceSync = true}) async { + if (_isSyncing) { return; } - if (shouldSyncPeople) { + _isSyncing = true; + if (forceSync) { await PersonService.instance.reconcileClusters(); - shouldSyncPeople = false; + _shouldSyncPeople = false; } + _isSyncing = false; + } + + Future indexAndClusterAll() async { + if (_cannotRunMLFunction()) return; + + await sync(forceSync: _shouldSyncPeople); await indexAllImages(); final indexingCompleteRatio = await _getIndexedDoneRatio(); if (indexingCompleteRatio < 0.95) { _logger.info( - "Indexing is not far enough, skipping clustering. Indexing is at $indexingCompleteRatio", + "Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio", ); return; } else { @@ -368,35 +372,195 @@ class FaceMlService { } } + void pauseIndexingAndClustering() { + if (_isIndexingOrClusteringRunning) { + _shouldPauseIndexingAndClustering = true; + } + } + + /// Analyzes all the images in the database with the latest ml version and stores the results in the database. + /// + /// This function first checks if the image has already been analyzed with the lastest faceMlVersion and stored in the database. If so, it skips the image. + Future indexAllImages({int retryFetchCount = 10}) async { + if (_cannotRunMLFunction()) return; + + try { + _isIndexingOrClusteringRunning = true; + _logger.info('starting image indexing'); + + final w = (kDebugMode ? EnteWatch('prepare indexing files') : null) + ?..start(); + final Map alreadyIndexedFiles = + await FaceMLDataDB.instance.getIndexedFileIds(); + w?.log('getIndexedFileIds'); + final List enteFiles = + await SearchService.instance.getAllFiles(); + w?.log('getAllFiles'); + + // Make sure the image conversion isolate is spawned + // await ImageMlIsolate.instance.ensureSpawned(); + await ensureInitialized(); + + int fileAnalyzedCount = 0; + int fileSkippedCount = 0; + final stopwatch = Stopwatch()..start(); + final List filesWithLocalID = []; + final List filesWithoutLocalID = []; + final List hiddenFilesToIndex = []; + w?.log('getIndexableFileIDs'); + + for (final EnteFile enteFile in enteFiles) { + if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) { + fileSkippedCount++; + continue; + } + if ((enteFile.localID ?? '').isEmpty) { + filesWithoutLocalID.add(enteFile); + } else { + filesWithLocalID.add(enteFile); + } + } + w?.log('sifting through all normal files'); + final List hiddenFiles = + await SearchService.instance.getHiddenFiles(); + w?.log('getHiddenFiles: ${hiddenFiles.length} hidden files'); + for (final EnteFile enteFile in hiddenFiles) { + if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) { + fileSkippedCount++; + continue; + } + hiddenFilesToIndex.add(enteFile); + } + + // list of files where files with localID are first + final sortedBylocalID = []; + sortedBylocalID.addAll(filesWithLocalID); + sortedBylocalID.addAll(filesWithoutLocalID); + sortedBylocalID.addAll(hiddenFilesToIndex); + w?.log('preparing all files to index'); + final List> chunks = + sortedBylocalID.chunks(_embeddingFetchLimit); + int fetchedCount = 0; + outerLoop: + for (final chunk in chunks) { + final futures = >[]; + + if (LocalSettings.instance.remoteFetchEnabled) { + try { + final List fileIds = []; + // Try to find embeddings on the remote server + for (final f in chunk) { + fileIds.add(f.uploadedFileID!); + } + _logger.info('starting remote fetch for ${fileIds.length} files'); + final res = + await RemoteFileMLService.instance.getFilessEmbedding(fileIds); + _logger.info('fetched ${res.mlData.length} embeddings'); + fetchedCount += res.mlData.length; + final List faces = []; + final remoteFileIdToVersion = {}; + for (FileMl fileMl in res.mlData.values) { + if (_shouldDiscardRemoteEmbedding(fileMl)) continue; + if (fileMl.faceEmbedding.faces.isEmpty) { + faces.add( + Face.empty( + fileMl.fileID, + ), + ); + } else { + for (final f in fileMl.faceEmbedding.faces) { + f.fileInfo = FileInfo( + imageHeight: fileMl.height, + imageWidth: fileMl.width, + ); + faces.add(f); + } + } + remoteFileIdToVersion[fileMl.fileID] = + fileMl.faceEmbedding.version; + } + if (res.noEmbeddingFileIDs.isNotEmpty) { + _logger.info( + 'No embeddings found for ${res.noEmbeddingFileIDs.length} files', + ); + for (final fileID in res.noEmbeddingFileIDs) { + faces.add(Face.empty(fileID, error: false)); + remoteFileIdToVersion[fileID] = faceMlVersion; + } + } + + await FaceMLDataDB.instance.bulkInsertFaces(faces); + _logger.info('stored embeddings'); + for (final entry in remoteFileIdToVersion.entries) { + alreadyIndexedFiles[entry.key] = entry.value; + } + _logger + .info('already indexed files ${remoteFileIdToVersion.length}'); + } catch (e, s) { + _logger.severe("err while getting files embeddings", e, s); + if (retryFetchCount < 1000) { + Future.delayed(Duration(seconds: retryFetchCount), () { + unawaited(indexAllImages(retryFetchCount: retryFetchCount * 2)); + }); + return; + } else { + _logger.severe( + "Failed to fetch embeddings for files after multiple retries", + e, + s, + ); + rethrow; + } + } + } + if (!await canUseHighBandwidth()) { + continue; + } + final smallerChunks = chunk.chunks(_fileDownloadLimit); + for (final smallestChunk in smallerChunks) { + for (final enteFile in smallestChunk) { + if (_shouldPauseIndexingAndClustering) { + _logger.info("indexAllImages() was paused, stopping"); + break outerLoop; + } + if (_skipAnalysisEnteFile( + enteFile, + alreadyIndexedFiles, + )) { + fileSkippedCount++; + continue; + } + futures.add(processImage(enteFile)); + } + final awaitedFutures = await Future.wait(futures); + final sumFutures = awaitedFutures.fold( + 0, + (previousValue, element) => previousValue + (element ? 1 : 0), + ); + fileAnalyzedCount += sumFutures; + } + } + + stopwatch.stop(); + _logger.info( + "`indexAllImages()` finished. Fetched $fetchedCount and analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images. MLController status: $_mlControllerStatus)", + ); + } catch (e, s) { + _logger.severe("indexAllImages failed", e, s); + } finally { + _isIndexingOrClusteringRunning = false; + _shouldPauseIndexingAndClustering = false; + } + } + Future clusterAllImages({ double minFaceScore = kMinimumQualityFaceScore, bool clusterInBuckets = true, }) async { - if (!canRunMLController) { - _logger - .info("MLController does not allow running ML, skipping clustering"); - return; - } - if (isClusteringRunning) { - _logger.info("clusterAllImages is already running, skipping"); - return; - } - // verify faces is enabled - if (LocalSettings.instance.isFaceIndexingEnabled == false) { - _logger.warning("clustering is disabled by user"); - return; - } - - final indexingCompleteRatio = await _getIndexedDoneRatio(); - if (indexingCompleteRatio < 0.95) { - _logger.info( - "Indexing is not far enough, skipping clustering. Indexing is at $indexingCompleteRatio", - ); - return; - } + if (_cannotRunMLFunction()) return; _logger.info("`clusterAllImages()` called"); - isClusteringRunning = true; + _isIndexingOrClusteringRunning = true; final clusterAllImagesTime = DateTime.now(); try { @@ -441,7 +605,7 @@ class FaceMlService { int bucket = 1; while (true) { - if (!canRunMLController) { + if (_shouldPauseIndexingAndClustering) { _logger.info( "MLController does not allow running ML, stopping before clustering bucket $bucket", ); @@ -535,193 +699,12 @@ class FaceMlService { } catch (e, s) { _logger.severe("`clusterAllImages` failed", e, s); } finally { - isClusteringRunning = false; + _isIndexingOrClusteringRunning = false; + _shouldPauseIndexingAndClustering = false; } } - /// Analyzes all the images in the database with the latest ml version and stores the results in the database. - /// - /// This function first checks if the image has already been analyzed with the lastest faceMlVersion and stored in the database. If so, it skips the image. - Future indexAllImages({int retryFetchCount = 10}) async { - if (isImageIndexRunning) { - _logger.warning("indexAllImages is already running, skipping"); - return; - } - // verify faces is enabled - if (LocalSettings.instance.isFaceIndexingEnabled == false) { - _logger.warning("indexing is disabled by user"); - return; - } - try { - isImageIndexRunning = true; - _logger.info('starting image indexing'); - - final w = (kDebugMode ? EnteWatch('prepare indexing files') : null) - ?..start(); - final Map alreadyIndexedFiles = - await FaceMLDataDB.instance.getIndexedFileIds(); - w?.log('getIndexedFileIds'); - final List enteFiles = - await SearchService.instance.getAllFiles(); - w?.log('getAllFiles'); - - // Make sure the image conversion isolate is spawned - // await ImageMlIsolate.instance.ensureSpawned(); - await ensureInitialized(); - - int fileAnalyzedCount = 0; - int fileSkippedCount = 0; - final stopwatch = Stopwatch()..start(); - final List filesWithLocalID = []; - final List filesWithoutLocalID = []; - final List hiddenFilesToIndex = []; - w?.log('getIndexableFileIDs'); - - for (final EnteFile enteFile in enteFiles) { - if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) { - fileSkippedCount++; - continue; - } - if ((enteFile.localID ?? '').isEmpty) { - filesWithoutLocalID.add(enteFile); - } else { - filesWithLocalID.add(enteFile); - } - } - w?.log('sifting through all normal files'); - final List hiddenFiles = - await SearchService.instance.getHiddenFiles(); - w?.log('getHiddenFiles: ${hiddenFiles.length} hidden files'); - for (final EnteFile enteFile in hiddenFiles) { - if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) { - fileSkippedCount++; - continue; - } - hiddenFilesToIndex.add(enteFile); - } - - // list of files where files with localID are first - final sortedBylocalID = []; - sortedBylocalID.addAll(filesWithLocalID); - sortedBylocalID.addAll(filesWithoutLocalID); - sortedBylocalID.addAll(hiddenFilesToIndex); - w?.log('preparing all files to index'); - final List> chunks = - sortedBylocalID.chunks(_embeddingFetchLimit); - outerLoop: - for (final chunk in chunks) { - final futures = >[]; - - if (LocalSettings.instance.remoteFetchEnabled) { - try { - final List fileIds = []; - // Try to find embeddings on the remote server - for (final f in chunk) { - fileIds.add(f.uploadedFileID!); - } - final EnteWatch? w = - flagService.internalUser ? EnteWatch("face_em_fetch") : null; - w?.start(); - w?.log('starting remote fetch for ${fileIds.length} files'); - final res = - await RemoteFileMLService.instance.getFilessEmbedding(fileIds); - w?.logAndReset('fetched ${res.mlData.length} embeddings'); - final List faces = []; - final remoteFileIdToVersion = {}; - for (FileMl fileMl in res.mlData.values) { - if (shouldDiscardRemoteEmbedding(fileMl)) continue; - if (fileMl.faceEmbedding.faces.isEmpty) { - faces.add( - Face.empty( - fileMl.fileID, - ), - ); - } else { - for (final f in fileMl.faceEmbedding.faces) { - f.fileInfo = FileInfo( - imageHeight: fileMl.height, - imageWidth: fileMl.width, - ); - faces.add(f); - } - } - remoteFileIdToVersion[fileMl.fileID] = - fileMl.faceEmbedding.version; - } - if (res.noEmbeddingFileIDs.isNotEmpty) { - _logger.info( - 'No embeddings found for ${res.noEmbeddingFileIDs.length} files', - ); - for (final fileID in res.noEmbeddingFileIDs) { - faces.add(Face.empty(fileID, error: false)); - remoteFileIdToVersion[fileID] = faceMlVersion; - } - } - - await FaceMLDataDB.instance.bulkInsertFaces(faces); - w?.logAndReset('stored embeddings'); - for (final entry in remoteFileIdToVersion.entries) { - alreadyIndexedFiles[entry.key] = entry.value; - } - _logger - .info('already indexed files ${remoteFileIdToVersion.length}'); - } catch (e, s) { - _logger.severe("err while getting files embeddings", e, s); - if (retryFetchCount < 1000) { - Future.delayed(Duration(seconds: retryFetchCount), () { - unawaited(indexAllImages(retryFetchCount: retryFetchCount * 2)); - }); - return; - } else { - _logger.severe( - "Failed to fetch embeddings for files after multiple retries", - e, - s, - ); - rethrow; - } - } - } - if (!await canUseHighBandwidth()) { - continue; - } - final smallerChunks = chunk.chunks(_fileDownloadLimit); - for (final smallestChunk in smallerChunks) { - for (final enteFile in smallestChunk) { - if (isImageIndexRunning == false) { - _logger.info("indexAllImages() was paused, stopping"); - break outerLoop; - } - if (_skipAnalysisEnteFile( - enteFile, - alreadyIndexedFiles, - )) { - fileSkippedCount++; - continue; - } - futures.add(processImage(enteFile)); - } - final awaitedFutures = await Future.wait(futures); - final sumFutures = awaitedFutures.fold( - 0, - (previousValue, element) => previousValue + (element ? 1 : 0), - ); - fileAnalyzedCount += sumFutures; - } - } - - stopwatch.stop(); - _logger.info( - "`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images. MLController status: $canRunMLController)", - ); - } catch (e, s) { - _logger.severe("indexAllImages failed", e, s); - } finally { - isImageIndexRunning = false; - } - } - - bool shouldDiscardRemoteEmbedding(FileMl fileMl) { + bool _shouldDiscardRemoteEmbedding(FileMl fileMl) { if (fileMl.faceEmbedding.version < faceMlVersion) { debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " "because version is ${fileMl.faceEmbedding.version} and we need $faceMlVersion"); @@ -769,7 +752,7 @@ class FaceMlService { ); try { - final FaceMlResult? result = await analyzeImageInSingleIsolate( + final FaceMlResult? result = await _analyzeImageInSingleIsolate( enteFile, // preferUsingThumbnailForEverything: false, // disposeImageIsolateAfterUse: false, @@ -861,12 +844,8 @@ class FaceMlService { } } - void pauseIndexing() { - isImageIndexRunning = false; - } - /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate. - Future analyzeImageInSingleIsolate(EnteFile enteFile) async { + Future _analyzeImageInSingleIsolate(EnteFile enteFile) async { _checkEnteFileForID(enteFile); await ensureInitialized(); @@ -1057,94 +1036,6 @@ class FaceMlService { return imagePath; } - @Deprecated('Deprecated in favor of `_getImagePathForML`') - Future _getDataForML( - EnteFile enteFile, { - FileDataForML typeOfData = FileDataForML.fileData, - }) async { - Uint8List? data; - - switch (typeOfData) { - case FileDataForML.fileData: - final stopwatch = Stopwatch()..start(); - final File? actualIoFile = await getFile(enteFile, isOrigin: true); - if (actualIoFile != null) { - data = await actualIoFile.readAsBytes(); - } - stopwatch.stop(); - _logger.info( - "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", - ); - - break; - - case FileDataForML.thumbnailData: - final stopwatch = Stopwatch()..start(); - data = await getThumbnail(enteFile); - stopwatch.stop(); - _logger.info( - "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", - ); - break; - - case FileDataForML.compressedFileData: - final stopwatch = Stopwatch()..start(); - final String tempPath = Configuration.instance.getTempDirectory() + - "${enteFile.uploadedFileID!}"; - final File? actualIoFile = await getFile(enteFile); - if (actualIoFile != null) { - final compressResult = await FlutterImageCompress.compressAndGetFile( - actualIoFile.path, - tempPath + ".jpg", - ); - if (compressResult != null) { - data = await compressResult.readAsBytes(); - } - } - stopwatch.stop(); - _logger.info( - "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", - ); - break; - } - - return data; - } - - /// Detects faces in the given image data. - /// - /// `imageData`: The image data to analyze. - /// - /// Returns a list of face detection results. - /// - /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong. - Future> _detectFacesIsolate( - String imagePath, - // Uint8List fileData, - { - FaceMlResultBuilder? resultBuilder, - }) async { - try { - // Get the bounding boxes of the faces - final (List faces, dataSize) = - await FaceDetectionService.instance.predictInComputer(imagePath); - - // Add detected faces to the resultBuilder - if (resultBuilder != null) { - resultBuilder.addNewlyDetectedFaces(faces, dataSize); - } - - return faces; - } on YOLOFaceInterpreterInitializationException { - throw CouldNotInitializeFaceDetector(); - } on YOLOFaceInterpreterRunException { - throw CouldNotRunFaceDetector(); - } catch (e) { - _logger.severe('Face detection failed: $e'); - throw GeneralFaceMlException('Face detection failed: $e'); - } - } - /// Detects faces in the given image data. /// /// `imageData`: The image data to analyze. @@ -1183,38 +1074,6 @@ class FaceMlService { } } - /// Aligns multiple faces from the given image data. - /// - /// `imageData`: The image data in [Uint8List] that contains the faces. - /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align. - /// - /// Returns a list of the aligned faces as image data. - /// - /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails. - Future _alignFaces( - String imagePath, - List faces, { - FaceMlResultBuilder? resultBuilder, - }) async { - try { - final (alignedFaces, alignmentResults, _, blurValues, _) = - await ImageMlIsolate.instance - .preprocessMobileFaceNetOnnx(imagePath, faces); - - if (resultBuilder != null) { - resultBuilder.addAlignmentResults( - alignmentResults, - blurValues, - ); - } - - return alignedFaces; - } catch (e, s) { - _logger.severe('Face alignment failed: $e', e, s); - throw CouldNotWarpAffine(); - } - } - /// Aligns multiple faces from the given image data. /// /// `imageData`: The image data in [Uint8List] that contains the faces. @@ -1256,45 +1115,6 @@ class FaceMlService { } } - /// Embeds multiple faces from the given input matrices. - /// - /// `facesMatrices`: The input matrices of the faces to embed. - /// - /// Returns a list of the face embeddings as lists of doubles. - /// - /// Throws [CouldNotInitializeFaceEmbeddor], [CouldNotRunFaceEmbeddor], [InputProblemFaceEmbeddor] or [GeneralFaceMlException] if the face embedding fails. - Future>> _embedFaces( - Float32List facesList, { - FaceMlResultBuilder? resultBuilder, - }) async { - try { - // Get the embedding of the faces - final List> embeddings = - await FaceEmbeddingService.instance.predictInComputer(facesList); - - // Add the embeddings to the resultBuilder - if (resultBuilder != null) { - resultBuilder.addEmbeddingsToExistingFaces(embeddings); - } - - return embeddings; - } on MobileFaceNetInterpreterInitializationException { - throw CouldNotInitializeFaceEmbeddor(); - } on MobileFaceNetInterpreterRunException { - throw CouldNotRunFaceEmbeddor(); - } on MobileFaceNetEmptyInput { - throw InputProblemFaceEmbeddor("Input is empty"); - } on MobileFaceNetWrongInputSize { - throw InputProblemFaceEmbeddor("Input size is wrong"); - } on MobileFaceNetWrongInputRange { - throw InputProblemFaceEmbeddor("Input range is wrong"); - // ignore: avoid_catches_without_on_clauses - } catch (e) { - _logger.severe('Face embedding (batch) failed: $e'); - throw GeneralFaceMlException('Face embedding (batch) failed: $e'); - } - } - static Future>> embedFacesSync( Float32List facesList, int interpreterAddress, { @@ -1334,10 +1154,9 @@ class FaceMlService { _logger.warning( '''Skipped analysis of image with enteFile, it might be the wrong format or has no uploadedFileID, or MLController doesn't allow it to run. enteFile: ${enteFile.toString()} - isImageIndexRunning: $isImageIndexRunning - canRunML: $canRunMLController ''', ); + _logStatus(); throw CouldNotRetrieveAnyFileData(); } } @@ -1361,7 +1180,8 @@ class FaceMlService { } bool _skipAnalysisEnteFile(EnteFile enteFile, Map indexedFileIds) { - if (isImageIndexRunning == false || canRunMLController == false) { + if (_isIndexingOrClusteringRunning == false || + _mlControllerStatus == false) { return true; } // Skip if the file is not uploaded or not owned by the user @@ -1378,4 +1198,50 @@ class FaceMlService { return indexedFileIds.containsKey(id) && indexedFileIds[id]! >= faceMlVersion; } + + bool _cannotRunMLFunction({String function = ""}) { + if (_isIndexingOrClusteringRunning) { + _logger.info( + "Cannot run $function because indexing or clustering is already running", + ); + _logStatus(); + return true; + } + if (_mlControllerStatus == false) { + _logger.info( + "Cannot run $function because MLController does not allow it", + ); + _logStatus(); + return true; + } + if (debugIndexingDisabled) { + _logger.info( + "Cannot run $function because debugIndexingDisabled is true", + ); + _logStatus(); + return true; + } + if (_shouldPauseIndexingAndClustering) { + // This should ideally not be triggered, because one of the above should be triggered instead. + _logger.warning( + "Cannot run $function because indexing and clustering is being paused", + ); + _logStatus(); + return true; + } + return false; + } + + void _logStatus() { + final String status = ''' + isInternalUser: ${flagService.internalUser} + isFaceIndexingEnabled: ${LocalSettings.instance.isFaceIndexingEnabled} + canRunMLController: $_mlControllerStatus + isIndexingOrClusteringRunning: $_isIndexingOrClusteringRunning + shouldPauseIndexingAndClustering: $_shouldPauseIndexingAndClustering + debugIndexingDisabled: $debugIndexingDisabled + shouldSyncPeople: $_shouldSyncPeople + '''; + _logger.info(status); + } } diff --git a/mobile/lib/services/machine_learning/machine_learning_controller.dart b/mobile/lib/services/machine_learning/machine_learning_controller.dart index 65daf614ce..852ebcd5b5 100644 --- a/mobile/lib/services/machine_learning/machine_learning_controller.dart +++ b/mobile/lib/services/machine_learning/machine_learning_controller.dart @@ -3,6 +3,8 @@ import "dart:io"; import "package:battery_info/battery_info_plugin.dart"; import "package:battery_info/model/android_battery_info.dart"; +import "package:battery_info/model/iso_battery_info.dart"; +import "package:flutter/foundation.dart" show kDebugMode; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/machine_learning_control_event.dart"; @@ -17,7 +19,8 @@ class MachineLearningController { static const kMaximumTemperature = 42; // 42 degree celsius static const kMinimumBatteryLevel = 20; // 20% - static const kDefaultInteractionTimeout = Duration(seconds: 15); + static const kDefaultInteractionTimeout = + kDebugMode ? Duration(seconds: 3) : Duration(seconds: 5); static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"]; bool _isDeviceHealthy = true; @@ -31,13 +34,17 @@ class MachineLearningController { BatteryInfoPlugin() .androidBatteryInfoStream .listen((AndroidBatteryInfo? batteryInfo) { - _onBatteryStateUpdate(batteryInfo); + _onAndroidBatteryStateUpdate(batteryInfo); }); - } else { - // Always run Machine Learning on iOS - _canRunML = true; - Bus.instance.fire(MachineLearningControlEvent(true)); } + if (Platform.isIOS) { + BatteryInfoPlugin() + .iosBatteryInfoStream + .listen((IosBatteryInfo? batteryInfo) { + _oniOSBatteryStateUpdate(batteryInfo); + }); + } + _fireControlEvent(); } void onUserInteraction() { @@ -53,7 +60,8 @@ class MachineLearningController { } void _fireControlEvent() { - final shouldRunML = _isDeviceHealthy && !_isUserInteracting; + final shouldRunML = + _isDeviceHealthy && (Platform.isAndroid ? !_isUserInteracting : true); if (shouldRunML != _canRunML) { _canRunML = shouldRunML; _logger.info( @@ -76,18 +84,28 @@ class MachineLearningController { _startInteractionTimer(); } - void _onBatteryStateUpdate(AndroidBatteryInfo? batteryInfo) { + void _onAndroidBatteryStateUpdate(AndroidBatteryInfo? batteryInfo) { _logger.info("Battery info: ${batteryInfo!.toJson()}"); - _isDeviceHealthy = _computeIsDeviceHealthy(batteryInfo); + _isDeviceHealthy = _computeIsAndroidDeviceHealthy(batteryInfo); _fireControlEvent(); } - bool _computeIsDeviceHealthy(AndroidBatteryInfo info) { + void _oniOSBatteryStateUpdate(IosBatteryInfo? batteryInfo) { + _logger.info("Battery info: ${batteryInfo!.toJson()}"); + _isDeviceHealthy = _computeIsiOSDeviceHealthy(batteryInfo); + _fireControlEvent(); + } + + bool _computeIsAndroidDeviceHealthy(AndroidBatteryInfo info) { return _hasSufficientBattery(info.batteryLevel ?? kMinimumBatteryLevel) && _isAcceptableTemperature(info.temperature ?? kMaximumTemperature) && _isBatteryHealthy(info.health ?? ""); } + bool _computeIsiOSDeviceHealthy(IosBatteryInfo info) { + return _hasSufficientBattery(info.batteryLevel ?? kMinimumBatteryLevel); + } + bool _hasSufficientBattery(int batteryLevel) { return batteryLevel >= kMinimumBatteryLevel; } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index d85b4ceb5d..db1713c2c3 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -1,6 +1,5 @@ import "dart:async"; import "dart:collection"; -import "dart:io"; import "dart:math" show min; import "package:computer/computer.dart"; @@ -103,17 +102,13 @@ class SemanticSearchService { if (shouldSyncImmediately) { unawaited(sync()); } - if (Platform.isAndroid) { - Bus.instance.on().listen((event) { - if (event.shouldRun) { - _startIndexing(); - } else { - _pauseIndexing(); - } - }); - } else { - _startIndexing(); - } + Bus.instance.on().listen((event) { + if (event.shouldRun) { + _startIndexing(); + } else { + _pauseIndexing(); + } + }); } Future release() async { diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 5e21b03345..1ff73dbc89 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -848,8 +848,9 @@ class SearchService { final String clusterName = "$clusterId"; if (clusterIDToPersonID[clusterId] != null) { - throw Exception( - "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", + // This should not happen, means a faceID is assigned to multiple persons. + _logger.severe( + "`getAllFace`: Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", ); } if (files.length < kMinimumClusterSizeSearchResult && diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index 01b10ff80c..726a9f2ceb 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -79,7 +79,7 @@ class _FaceDebugSectionWidgetState extends State { final isEnabled = await LocalSettings.instance.toggleFaceIndexing(); if (!isEnabled) { - FaceMlService.instance.pauseIndexing(); + FaceMlService.instance.pauseIndexingAndClustering(); } if (mounted) { setState(() {}); @@ -107,7 +107,7 @@ class _FaceDebugSectionWidgetState extends State { setState(() {}); } } catch (e, s) { - _logger.warning('indexing failed ', e, s); + _logger.warning('Remote fetch toggle failed ', e, s); await showGenericErrorDialog(context: context, error: e); } }, @@ -115,22 +115,25 @@ class _FaceDebugSectionWidgetState extends State { sectionOptionSpacing, MenuItemWidget( captionedTextWidget: CaptionedTextWidget( - title: FaceMlService.instance.canRunMLController - ? "canRunML enabled" - : "canRunML disabled", + title: FaceMlService.instance.debugIndexingDisabled + ? "Debug enable indexing again" + : "Debug disable indexing", ), pressedColor: getEnteColorScheme(context).fillFaint, trailingIcon: Icons.chevron_right_outlined, trailingIconIsMuted: true, onTap: () async { try { - FaceMlService.instance.canRunMLController = - !FaceMlService.instance.canRunMLController; + FaceMlService.instance.debugIndexingDisabled = + !FaceMlService.instance.debugIndexingDisabled; + if (FaceMlService.instance.debugIndexingDisabled) { + FaceMlService.instance.pauseIndexingAndClustering(); + } if (mounted) { setState(() {}); } } catch (e, s) { - _logger.warning('canRunML toggle failed ', e, s); + _logger.warning('debugIndexingDisabled toggle failed ', e, s); await showGenericErrorDialog(context: context, error: e); } }, @@ -145,6 +148,7 @@ class _FaceDebugSectionWidgetState extends State { trailingIconIsMuted: true, onTap: () async { try { + FaceMlService.instance.debugIndexingDisabled = false; unawaited(FaceMlService.instance.indexAndClusterAll()); } catch (e, s) { _logger.warning('indexAndClusterAll failed ', e, s); @@ -162,6 +166,7 @@ class _FaceDebugSectionWidgetState extends State { trailingIconIsMuted: true, onTap: () async { try { + FaceMlService.instance.debugIndexingDisabled = false; unawaited(FaceMlService.instance.indexAllImages()); } catch (e, s) { _logger.warning('indexing failed ', e, s); @@ -189,6 +194,7 @@ class _FaceDebugSectionWidgetState extends State { onTap: () async { try { await PersonService.instance.storeRemoteFeedback(); + FaceMlService.instance.debugIndexingDisabled = false; await FaceMlService.instance .clusterAllImages(clusterInBuckets: true); Bus.instance.fire(PeopleChangedEvent()); diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index 1e63cf6458..47e2166282 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -208,7 +208,7 @@ class _MachineLearningSettingsPageState if (isEnabled) { unawaited(FaceMlService.instance.ensureInitialized()); } else { - FaceMlService.instance.pauseIndexing(); + FaceMlService.instance.pauseIndexingAndClustering(); } if (mounted) { setState(() {}); diff --git a/mobile/plugins/ente_feature_flag/lib/src/service.dart b/mobile/plugins/ente_feature_flag/lib/src/service.dart index 8d7f226794..ce90352030 100644 --- a/mobile/plugins/ente_feature_flag/lib/src/service.dart +++ b/mobile/plugins/ente_feature_flag/lib/src/service.dart @@ -67,7 +67,7 @@ class FlagService { bool get mapEnabled => flags.mapEnabled; - bool get faceSearchEnabled => internalUser || flags.faceSearchEnabled; + bool get faceSearchEnabled => internalUser || flags.betaUser; bool get passKeyEnabled => flags.passKeyEnabled || internalOrBetaUser; diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index b9d5345c39..31c010671e 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -12,7 +12,7 @@ description: ente photos application # Read more about iOS versioning at # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html -version: 0.8.98+618 +version: 0.8.106+630 publish_to: none environment: diff --git a/web/apps/photos/src/services/download/index.ts b/web/apps/photos/src/services/download/index.ts index d0be660c9f..eb979af875 100644 --- a/web/apps/photos/src/services/download/index.ts +++ b/web/apps/photos/src/services/download/index.ts @@ -11,7 +11,7 @@ import { Remote } from "comlink"; import isElectron from "is-electron"; import * as ffmpeg from "services/ffmpeg"; import { EnteFile } from "types/file"; -import { generateStreamFromArrayBuffer, getRenderableImage } from "utils/file"; +import { getRenderableImage } from "utils/file"; import { PhotosDownloadClient } from "./clients/photos"; import { PublicAlbumsDownloadClient } from "./clients/publicAlbums"; @@ -289,7 +289,7 @@ class DownloadManagerImpl { await this.cryptoWorker.fromB64(file.file.decryptionHeader), file.key, ); - return generateStreamFromArrayBuffer(decrypted); + return new Response(decrypted).body; } catch (e) { if (e.message === CustomError.PROCESSING_FAILED) { log.error( @@ -318,81 +318,78 @@ class DownloadManagerImpl { const contentLength = +res.headers.get("Content-Length") ?? 0; let downloadedBytes = 0; - const stream = new ReadableStream({ - start: async (controller) => { - try { - const decryptionHeader = await this.cryptoWorker.fromB64( - file.file.decryptionHeader, - ); - const fileKey = await this.cryptoWorker.fromB64(file.key); - const { pullState, decryptionChunkSize } = - await this.cryptoWorker.initChunkDecryption( - decryptionHeader, - fileKey, + const decryptionHeader = await this.cryptoWorker.fromB64( + file.file.decryptionHeader, + ); + const fileKey = await this.cryptoWorker.fromB64(file.key); + const { pullState, decryptionChunkSize } = + await this.cryptoWorker.initChunkDecryption( + decryptionHeader, + fileKey, + ); + + let leftoverBytes = new Uint8Array(); + + return new ReadableStream({ + pull: async (controller) => { + // Each time pull is called, we want to enqueue at least once. + let didEnqueue = false; + do { + // done is a boolean and value is an Uint8Array. When done + // is true value will be empty. + const { done, value } = await reader.read(); + + let data: Uint8Array; + if (done) { + data = leftoverBytes; + } else { + downloadedBytes += value.length; + onDownloadProgress({ + loaded: downloadedBytes, + total: contentLength, + }); + + data = new Uint8Array( + leftoverBytes.length + value.length, ); - - let data = new Uint8Array(); - let more = true; - while (more) { - more = false; - - // "done" is a Boolean and value a "Uint8Array" - const { done, value } = await reader.read(); - - // Is there more data to read? - if (!done) { - downloadedBytes += value.length; - onDownloadProgress({ - loaded: downloadedBytes, - total: contentLength, - }); - - const buffer = new Uint8Array( - data.length + value.length, - ); - buffer.set(new Uint8Array(data), 0); - buffer.set(new Uint8Array(value), data.length); - - // Note that buffer.length might be a multiple of - // decryptionChunkSize. We let these accumulate, and - // drain it all with a nested while loop when done. - - if (buffer.length > decryptionChunkSize) { - const { decryptedData } = - await this.cryptoWorker.decryptFileChunk( - buffer.slice(0, decryptionChunkSize), - pullState, - ); - controller.enqueue(decryptedData); - data = buffer.slice(decryptionChunkSize); - } else { - data = buffer; - } - more = true; - } else { - while (data && data.length) { - const { decryptedData } = - await this.cryptoWorker.decryptFileChunk( - data.slice(0, decryptionChunkSize), - pullState, - ); - controller.enqueue(decryptedData); - data = - data.length > decryptionChunkSize - ? data.slice(decryptionChunkSize) - : undefined; - } - controller.close(); - } + data.set(new Uint8Array(leftoverBytes), 0); + data.set(new Uint8Array(value), leftoverBytes.length); } - } catch (e) { - log.error("Failed to process file stream", e); - controller.error(e); - } + + // data.length might be a multiple of decryptionChunkSize, + // and we might need multiple iterations to drain it all. + while (data.length >= decryptionChunkSize) { + const { decryptedData } = + await this.cryptoWorker.decryptFileChunk( + data.slice(0, decryptionChunkSize), + pullState, + ); + controller.enqueue(decryptedData); + didEnqueue = true; + data = data.slice(decryptionChunkSize); + } + + if (done) { + // Send off the remaining bytes without waiting for a + // full chunk, no more bytes are going to come. + if (data.length) { + const { decryptedData } = + await this.cryptoWorker.decryptFileChunk( + data, + pullState, + ); + controller.enqueue(decryptedData); + } + // Don't loop again even if we didn't enqueue. + didEnqueue = true; + controller.close(); + } else { + // Save it for the next pull. + leftoverBytes = data; + } + } while (!didEnqueue); }, }); - - return stream; } trackDownloadProgress = (fileID: number, fileSize: number) => { diff --git a/web/apps/photos/src/services/export/index.ts b/web/apps/photos/src/services/export/index.ts index 3a68837e75..16472b3b4f 100644 --- a/web/apps/photos/src/services/export/index.ts +++ b/web/apps/photos/src/services/export/index.ts @@ -29,7 +29,6 @@ import { getNonEmptyPersonalCollections, } from "utils/collection"; import { - generateStreamFromArrayBuffer, getPersonalFiles, getUpdatedEXIFFileForDownload, mergeMetadata, @@ -1026,7 +1025,6 @@ class ExportService { videoExportName, ); - const imageStream = generateStreamFromArrayBuffer(livePhoto.imageData); await this.saveMetadataFile( collectionExportPath, imageExportName, @@ -1035,10 +1033,9 @@ class ExportService { await writeStream( electron, `${collectionExportPath}/${imageExportName}`, - imageStream, + new Response(livePhoto.imageData).body, ); - const videoStream = generateStreamFromArrayBuffer(livePhoto.videoData); await this.saveMetadataFile( collectionExportPath, videoExportName, @@ -1048,7 +1045,7 @@ class ExportService { await writeStream( electron, `${collectionExportPath}/${videoExportName}`, - videoStream, + new Response(livePhoto.videoData).body, ); } catch (e) { await fs.rm(`${collectionExportPath}/${imageExportName}`); diff --git a/web/apps/photos/src/utils/file/index.ts b/web/apps/photos/src/utils/file/index.ts index af5c06e8e0..f2f9932dd8 100644 --- a/web/apps/photos/src/utils/file/index.ts +++ b/web/apps/photos/src/utils/file/index.ts @@ -262,15 +262,6 @@ export async function decryptFile( } } -export function generateStreamFromArrayBuffer(data: Uint8Array) { - return new ReadableStream({ - async start(controller: ReadableStreamDefaultController) { - controller.enqueue(data); - controller.close(); - }, - }); -} - /** * The returned blob.type is filled in, whenever possible, with the MIME type of * the data that we're dealing with. @@ -649,7 +640,7 @@ async function downloadFileDesktop( imageFileName, fs.exists, ); - const imageStream = generateStreamFromArrayBuffer(imageData); + const imageStream = new Response(imageData).body; await writeStream( electron, `${downloadDir}/${imageExportName}`, @@ -661,7 +652,7 @@ async function downloadFileDesktop( videoFileName, fs.exists, ); - const videoStream = generateStreamFromArrayBuffer(videoData); + const videoStream = new Response(videoData).body; await writeStream( electron, `${downloadDir}/${videoExportName}`, diff --git a/web/packages/next/blob-cache.ts b/web/packages/next/blob-cache.ts index 7223d0fdc1..0db9464521 100644 --- a/web/packages/next/blob-cache.ts +++ b/web/packages/next/blob-cache.ts @@ -136,6 +136,10 @@ export const openBlobCache = async ( * * new Blob([arrayBuffer, andOrAnyArray, andOrstring]) * + * To convert from a Uint8Array/ArrayBuffer/Blob to a ReadableStream + * + * new Response(array).body + * * Refs: * - https://github.com/yigitunallar/arraybuffer-vs-blob * - https://stackoverflow.com/questions/11821096/what-is-the-difference-between-an-arraybuffer-and-a-blob