diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index c72b197b46..9b5f42f540 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -13,6 +13,8 @@ import "package:photos/face/model/face.dart"; import "package:photos/models/file/file.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; +import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; +import "package:photos/utils/ml_util.dart"; import 'package:sqlite_async/sqlite_async.dart'; /// Stores all data for the FacesML-related features. The database can be accessed by `FaceMLDataDB.instance.database`. @@ -249,7 +251,7 @@ class FaceMLDataDB { final List fileId = [recentFileID]; int? avatarFileId; if (avatarFaceId != null) { - avatarFileId = int.tryParse(avatarFaceId.split('_')[0]); + avatarFileId = tryGetFileIdFromFaceId(avatarFaceId); if (avatarFileId != null) { fileId.add(avatarFileId); } @@ -401,8 +403,10 @@ class FaceMLDataDB { final personID = map[personIdColumn] as String; final clusterID = map[fcClusterID] as int; final faceID = map[fcFaceId] as String; - result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {}) - .add(faceID); + result + .putIfAbsent(personID, () => {}) + .putIfAbsent(clusterID, () => {}) + .add(faceID); } return result; } @@ -476,8 +480,7 @@ class FaceMLDataDB { for (final map in maps) { final clusterID = map[fcClusterID] as int; final faceID = map[fcFaceId] as String; - final x = faceID.split('_').first; - final fileID = int.parse(x); + final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); } return result; @@ -665,19 +668,38 @@ class FaceMLDataDB { return maps.first['count'] as int; } - Future getClusteredFaceCount() async { + Future getClusteredFileCount() async { final db = await instance.asyncDB; final List> maps = await db.getAll( - 'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable', + 'SELECT $fcFaceId FROM $faceClustersTable', ); - return maps.first['count'] as int; + final Set fileIDs = {}; + for (final map in maps) { + final int fileID = getFileIdFromFaceId(map[fcFaceId] as String); + fileIDs.add(fileID); + } + return fileIDs.length; } - Future getClusteredToTotalFacesRatio() async { - final int totalFaces = await getTotalFaceCount(); - final int clusteredFaces = await getClusteredFaceCount(); + Future getClusteredToIndexableFilesRatio() async { + final int indexableFiles = (await getIndexableFileIDs()).length; + final int clusteredFiles = await getClusteredFileCount(); - return clusteredFaces / totalFaces; + return clusteredFiles / indexableFiles; + } + + Future getUnclusteredFaceCount() async { + final db = await instance.asyncDB; + const String query = ''' + SELECT f.$faceIDColumn + FROM $facesTable f + LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$fcFaceId + WHERE f.$faceScore > $kMinimumQualityFaceScore + AND f.$faceBlur > $kLaplacianHardThreshold + AND fc.$fcFaceId IS NULL + '''; + final List> maps = await db.getAll(query); + return maps.length; } Future getBlurryFaceCount([ @@ -795,7 +817,7 @@ class FaceMLDataDB { for (final map in maps) { final clusterID = map[clusterIDColumn] as int; final String faceID = map[fcFaceId] as String; - final fileID = int.parse(faceID.split('_').first); + final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); } return result; @@ -814,8 +836,8 @@ class FaceMLDataDB { final Map> result = {}; for (final map in maps) { final clusterID = map[fcClusterID] as int; - final faceId = map[fcFaceId] as String; - final fileID = int.parse(faceId.split("_").first); + final faceID = map[fcFaceId] as String; + final fileID = getFileIdFromFaceId(faceID); result[fileID] = (result[fileID] ?? {})..add(clusterID); } return result; @@ -964,7 +986,7 @@ class FaceMLDataDB { final Map faceIDToClusterID = {}; for (final row in faceIdsResult) { final faceID = row[fcFaceId] as String; - if (fileIds.contains(faceID.split('_').first)) { + if (fileIds.contains(getFileIdFromFaceId(faceID))) { maxClusterID += 1; faceIDToClusterID[faceID] = maxClusterID; } @@ -990,7 +1012,7 @@ class FaceMLDataDB { final Map faceIDToClusterID = {}; for (final row in faceIdsResult) { final faceID = row[fcFaceId] as String; - if (fileIds.contains(faceID.split('_').first)) { + if (fileIds.contains(getFileIdFromFaceId(faceID))) { maxClusterID += 1; faceIDToClusterID[faceID] = maxClusterID; } diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 1b8d9c3bd5..1a635b0f07 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -498,19 +498,8 @@ class FaceClusteringService { } } - // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first if (fileIDToCreationTime != null) { - faceInfos.sort((a, b) { - if (a.fileCreationTime == null && b.fileCreationTime == null) { - return 0; - } else if (a.fileCreationTime == null) { - return 1; - } else if (b.fileCreationTime == null) { - return -1; - } else { - return a.fileCreationTime!.compareTo(b.fileCreationTime!); - } - }); + _sortFaceInfosOnCreationTime(faceInfos); } // Sort the faceInfos such that the ones with null clusterId are at the end @@ -796,19 +785,8 @@ class FaceClusteringService { ); } - // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first if (fileIDToCreationTime != null) { - faceInfos.sort((a, b) { - if (a.fileCreationTime == null && b.fileCreationTime == null) { - return 0; - } else if (a.fileCreationTime == null) { - return 1; - } else if (b.fileCreationTime == null) { - return -1; - } else { - return a.fileCreationTime!.compareTo(b.fileCreationTime!); - } - }); + _sortFaceInfosOnCreationTime(faceInfos); } if (faceInfos.isEmpty) { @@ -996,19 +974,8 @@ class FaceClusteringService { ); } - // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first if (fileIDToCreationTime != null) { - faceInfos.sort((a, b) { - if (a.fileCreationTime == null && b.fileCreationTime == null) { - return 0; - } else if (a.fileCreationTime == null) { - return 1; - } else if (b.fileCreationTime == null) { - return -1; - } else { - return a.fileCreationTime!.compareTo(b.fileCreationTime!); - } - }); + _sortFaceInfosOnCreationTime(faceInfos); } // Get the embeddings @@ -1027,3 +994,20 @@ class FaceClusteringService { return clusteredFaceIDs; } } + +/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first +void _sortFaceInfosOnCreationTime( + List faceInfos, +) { + faceInfos.sort((b, a) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart index 19f954013e..9f87b87220 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -310,5 +310,9 @@ class FaceResultBuilder { } int getFileIdFromFaceId(String faceId) { - return int.parse(faceId.split("_")[0]); + return int.parse(faceId.split("_").first); } + +int? tryGetFileIdFromFaceId(String faceId) { + return int.tryParse(faceId.split("_").first); +} \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 38079753c2..5f3d15bdc4 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -12,7 +12,6 @@ import "package:flutter/foundation.dart" show debugPrint, kDebugMode; import "package:logging/logging.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:package_info_plus/package_info_plus.dart"; -import "package:photos/core/configuration.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; import "package:photos/events/diff_sync_complete_event.dart"; @@ -99,6 +98,7 @@ class FaceMlService { final int _fileDownloadLimit = 5; final int _embeddingFetchLimit = 200; + final int _kForceClusteringFaceCount = 4000; Future init({bool initializeImageMlIsolate = false}) async { if (LocalSettings.instance.isFaceIndexingEnabled == false) { @@ -358,16 +358,17 @@ class FaceMlService { if (_cannotRunMLFunction()) return; await sync(forceSync: _shouldSyncPeople); - await indexAllImages(); - final indexingCompleteRatio = await _getIndexedDoneRatio(); - if (indexingCompleteRatio < 0.95) { + + final int unclusteredFacesCount = + await FaceMLDataDB.instance.getUnclusteredFaceCount(); + if (unclusteredFacesCount > _kForceClusteringFaceCount) { _logger.info( - "Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio", + "There are $unclusteredFacesCount unclustered faces, doing clustering first", ); - return; - } else { await clusterAllImages(); } + await indexAllImages(); + await clusterAllImages(); } void pauseIndexingAndClustering() { @@ -445,7 +446,7 @@ class FaceMlService { if (LocalSettings.instance.remoteFetchEnabled) { try { - final List fileIds = []; + final Set fileIds = {}; // if there are duplicates here server returns 400 // Try to find embeddings on the remote server for (final f in chunk) { fileIds.add(f.uploadedFileID!); @@ -590,8 +591,8 @@ class FaceMlService { allFaceInfoForClustering.add(faceInfo); } } - // sort the embeddings based on file creation time, oldest first - allFaceInfoForClustering.sort((a, b) { + // sort the embeddings based on file creation time, newest first + allFaceInfoForClustering.sort((b, a) { return fileIDToCreationTime[a.fileID]! .compareTo(fileIDToCreationTime[b.fileID]!); }); @@ -1171,24 +1172,6 @@ class FaceMlService { } } - Future _getIndexedDoneRatio() async { - final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start(); - - final int alreadyIndexedCount = await FaceMLDataDB.instance - .getIndexedFileCount(minimumMlVersion: faceMlVersion); - final int totalIndexableCount = (await getIndexableFileIDs()).length; - final ratio = alreadyIndexedCount / totalIndexableCount; - - w?.log('getIndexedDoneRatio'); - - return ratio; - } - - static Future> getIndexableFileIDs() async { - return FilesDB.instance - .getOwnedFileIDs(Configuration.instance.getUserID()!); - } - bool _skipAnalysisEnteFile(EnteFile enteFile, Map indexedFileIds) { if (_isIndexingOrClusteringRunning == false || _mlControllerStatus == false) { diff --git a/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart index eafbc6323d..4712916d07 100644 --- a/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart +++ b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart @@ -1,6 +1,7 @@ import "dart:async"; import "dart:convert"; +import "package:computer/computer.dart"; import "package:logging/logging.dart"; import "package:photos/core/network/network.dart"; import "package:photos/db/files_db.dart"; @@ -16,6 +17,8 @@ import "package:shared_preferences/shared_preferences.dart"; class RemoteFileMLService { RemoteFileMLService._privateConstructor(); + static final Computer _computer = Computer.shared(); + static final RemoteFileMLService instance = RemoteFileMLService._privateConstructor(); @@ -52,13 +55,13 @@ class RemoteFileMLService { } Future getFilessEmbedding( - List fileIds, + Set fileIds, ) async { try { final res = await _dio.post( "/embeddings/files", data: { - "fileIDs": fileIds, + "fileIDs": fileIds.toList(), "model": 'file-ml-clip-face', }, ); @@ -107,15 +110,17 @@ class RemoteFileMLService { final input = EmbeddingsDecoderInput(embedding, fileKey); inputs.add(input); } - // todo: use compute or isolate - return decryptFileMLComputer( - { + return _computer.compute, Map>( + _decryptFileMLComputer, + param: { "inputs": inputs, }, ); } - Future> decryptFileMLComputer( +} + +Future> _decryptFileMLComputer( Map args, ) async { final result = {}; @@ -134,5 +139,4 @@ class RemoteFileMLService { result[input.embedding.fileID] = decodedEmbedding; } return result; - } -} + } \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index db1713c2c3..1384750811 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -23,6 +23,7 @@ import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx import "package:photos/utils/debouncer.dart"; import "package:photos/utils/device_info.dart"; import "package:photos/utils/local_settings.dart"; +import "package:photos/utils/ml_util.dart"; import "package:photos/utils/thumbnail_util.dart"; class SemanticSearchService { @@ -160,8 +161,7 @@ class SemanticSearchService { } Future getIndexStatus() async { - final indexableFileIDs = await FilesDB.instance - .getOwnedFileIDs(Configuration.instance.getUserID()!); + final indexableFileIDs = await getIndexableFileIDs(); return IndexStatus( min(_cachedEmbeddings.length, indexableFileIDs.length), (await _getFileIDsToBeIndexed()).length, @@ -222,8 +222,7 @@ class SemanticSearchService { } Future> _getFileIDsToBeIndexed() async { - final uploadedFileIDs = await FilesDB.instance - .getOwnedFileIDs(Configuration.instance.getUserID()!); + final uploadedFileIDs = await getIndexableFileIDs(); final embeddedFileIDs = await EmbeddingsDB.instance.getFileIDs(_currentModel); diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 1ff73dbc89..d15eddb718 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -754,15 +754,6 @@ class SearchService { Future> getAllFace(int? limit) async { try { - // Don't return anything if clustering is not nearly complete yet - final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount(); - final clusteredFaces = - await FaceMLDataDB.instance.getClusteredFaceCount(); - final clusteringDoneRatio = clusteredFaces / foundFaces; - if (clusteringDoneRatio < 0.9) { - return []; - } - debugPrint("getting faces"); final Map> fileIdToClusterID = await FaceMLDataDB.instance.getFileIdToClusterIds(); diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index 726a9f2ceb..376793769f 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -177,7 +177,7 @@ class _FaceDebugSectionWidgetState extends State { sectionOptionSpacing, MenuItemWidget( captionedTextWidget: FutureBuilder( - future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(), + future: FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(), builder: (context, snapshot) { if (snapshot.hasData) { return CaptionedTextWidget( diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index 47e2166282..4e9178a57d 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -26,6 +26,7 @@ import "package:photos/ui/components/title_bar_widget.dart"; import "package:photos/ui/components/toggle_switch_widget.dart"; import "package:photos/utils/data_util.dart"; import "package:photos/utils/local_settings.dart"; +import "package:photos/utils/ml_util.dart"; final _logger = Logger("MachineLearningSettingsPage"); @@ -438,19 +439,16 @@ class FaceRecognitionStatusWidgetState }); } - Future<(int, int, int, double)> getIndexStatus() async { + Future<(int, int, double)> getIndexStatus() async { try { final indexedFiles = await FaceMLDataDB.instance .getIndexedFileCount(minimumMlVersion: faceMlVersion); - final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length; + final indexableFiles = (await getIndexableFileIDs()).length; final showIndexedFiles = min(indexedFiles, indexableFiles); final pendingFiles = max(indexableFiles - indexedFiles, 0); - final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount(); - final clusteredFaces = - await FaceMLDataDB.instance.getClusteredFaceCount(); - final clusteringDoneRatio = clusteredFaces / foundFaces; + final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(); - return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio); + return (showIndexedFiles, pendingFiles, clusteringDoneRatio); } catch (e, s) { _logger.severe('Error getting face recognition status', e, s); rethrow; @@ -479,8 +477,7 @@ class FaceRecognitionStatusWidgetState if (snapshot.hasData) { final int indexedFiles = snapshot.data!.$1; final int pendingFiles = snapshot.data!.$2; - final int foundFaces = snapshot.data!.$3; - final double clusteringDoneRatio = snapshot.data!.$4; + final double clusteringDoneRatio = snapshot.data!.$3; final double clusteringPercentage = (clusteringDoneRatio * 100).clamp(0, 100); @@ -512,19 +509,6 @@ class FaceRecognitionStatusWidgetState isGestureDetectorDisabled: true, key: ValueKey("pending_items_" + pendingFiles.toString()), ), - MenuItemWidget( - captionedTextWidget: CaptionedTextWidget( - title: S.of(context).foundFaces, - ), - trailingWidget: Text( - NumberFormat().format(foundFaces), - style: Theme.of(context).textTheme.bodySmall, - ), - singleBorderRadius: 8, - alignCaptionedTextToLeft: true, - isGestureDetectorDisabled: true, - key: ValueKey("found_faces_" + foundFaces.toString()), - ), MenuItemWidget( captionedTextWidget: CaptionedTextWidget( title: S.of(context).clusteringProgress, diff --git a/mobile/lib/utils/ml_util.dart b/mobile/lib/utils/ml_util.dart new file mode 100644 index 0000000000..4033e29349 --- /dev/null +++ b/mobile/lib/utils/ml_util.dart @@ -0,0 +1,7 @@ +import "package:photos/core/configuration.dart"; +import "package:photos/db/files_db.dart"; + +Future> getIndexableFileIDs() async { + return FilesDB.instance + .getOwnedFileIDs(Configuration.instance.getUserID()!); + } \ No newline at end of file diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index a1f6607763..1417d17f3e 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -12,7 +12,7 @@ description: ente photos application # Read more about iOS versioning at # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html -version: 0.8.109+633 +version: 0.8.110+634 publish_to: none environment: diff --git a/web/apps/auth/package.json b/web/apps/auth/package.json index 463ff06e8d..268f6f5c68 100644 --- a/web/apps/auth/package.json +++ b/web/apps/auth/package.json @@ -7,6 +7,7 @@ "@ente/accounts": "*", "@ente/eslint-config": "*", "@ente/shared": "*", + "jssha": "~3.3.1", "otpauth": "^9" } } diff --git a/web/apps/auth/src/pages/auth.tsx b/web/apps/auth/src/pages/auth.tsx index 4006cc9e19..f6661b1b13 100644 --- a/web/apps/auth/src/pages/auth.tsx +++ b/web/apps/auth/src/pages/auth.tsx @@ -46,14 +46,11 @@ const AuthenticatorCodesPage = () => { appContext.showNavBar(false); }, []); + const lcSearch = searchTerm.toLowerCase(); const filteredCodes = codes.filter( - (secret) => - (secret.issuer ?? "") - .toLowerCase() - .includes(searchTerm.toLowerCase()) || - (secret.account ?? "") - .toLowerCase() - .includes(searchTerm.toLowerCase()), + (code) => + code.issuer?.toLowerCase().includes(lcSearch) || + code.account?.toLowerCase().includes(lcSearch), ); if (!hasFetched) { @@ -190,28 +187,21 @@ const CodeDisplay: React.FC = ({ code }) => { useEffect(() => { // Generate to set the initial otp and nextOTP on component mount. regen(); - const codeType = code.type; - const codePeriodInMs = code.period * 1000; - const timeToNextCode = - codePeriodInMs - (new Date().getTime() % codePeriodInMs); - const interval = null; + + const periodMs = code.period * 1000; + const timeToNextCode = periodMs - (Date.now() % periodMs); + + let interval: ReturnType | undefined; // Wait until we are at the start of the next code period, and then // start the interval loop. setTimeout(() => { // We need to call regen() once before the interval loop to set the // initial otp and nextOTP. regen(); - codeType.toLowerCase() === "totp" || - codeType.toLowerCase() === "hotp" - ? setInterval(() => { - regen(); - }, codePeriodInMs) - : null; + interval = setInterval(regen, periodMs); }, timeToNextCode); - return () => { - if (interval) clearInterval(interval); - }; + return () => interval && clearInterval(interval); }, [code]); return ( @@ -270,7 +260,7 @@ const OTPDisplay: React.FC = ({ code, otp, nextOTP }) => { textAlign: "left", }} > - {code.issuer} + {code.issuer ?? ""}

= ({ code, otp, nextOTP }) => { color: "grey", }} > - {code.account} + {code.account ?? ""}

= ({ period }) => { useEffect(() => { const advance = () => { - const timeRemaining = us - ((new Date().getTime() * 1000) % us); + const timeRemaining = us - ((Date.now() * 1000) % us); setProgress(timeRemaining / us); }; diff --git a/web/apps/auth/src/services/code.ts b/web/apps/auth/src/services/code.ts index 064cc78743..b5da0ffe55 100644 --- a/web/apps/auth/src/services/code.ts +++ b/web/apps/auth/src/services/code.ts @@ -1,5 +1,6 @@ +import { ensure } from "@/utils/ensure"; import { HOTP, TOTP } from "otpauth"; -import { URI } from "vscode-uri"; +import { Steam } from "./steam"; /** * A parsed representation of an *OTP code URI. @@ -10,13 +11,19 @@ export interface Code { /** A unique id for the corresponding "auth entity" in our system. */ id?: String; /** The type of the code. */ - type: "totp" | "hotp"; + type: "totp" | "hotp" | "steam"; /** The user's account or email for which this code is used. */ - account: string; + account?: string; /** The name of the entity that issued this code. */ issuer: string; - /** Number of digits in the generated OTP. */ - digits: number; + /** + * Length of the generated OTP. + * + * This is vernacularly called "digits", which is an accurate description + * for the OG TOTP/HOTP codes. However, steam codes are not just digits, so + * we name this as a content-neutral "length". + */ + length: number; /** * The time period (in seconds) for which a single OTP generated from this * code remains valid. @@ -32,7 +39,7 @@ export interface Code { /** The (HMAC) algorithm used by the OTP generator. */ algorithm: "sha1" | "sha256" | "sha512"; /** The original string from which this code was generated. */ - uriString?: string; + uriString: string; } /** @@ -45,100 +52,109 @@ export interface Code { * * - (TOTP) * otpauth://totp/ACME:user@example.org?algorithm=SHA1&digits=6&issuer=acme&period=30&secret=ALPHANUM + * + * See also `auth/test/models/code_test.dart`. */ export const codeFromURIString = (id: string, uriString: string): Code => { - const santizedRawData = uriString - .replaceAll("+", "%2B") - .replaceAll(":", "%3A") - .replaceAll("\r", "") - // trim quotes - .replace(/^"|"$/g, ""); - - const uriParams = {}; - const searchParamsString = - decodeURIComponent(santizedRawData).split("?")[1]; - searchParamsString.split("&").forEach((pair) => { - const [key, value] = pair.split("="); - uriParams[key] = value; - }); - - const uri = URI.parse(santizedRawData); - let uriPath = decodeURIComponent(uri.path); - if (uriPath.startsWith("/otpauth://") || uriPath.startsWith("otpauth://")) { - uriPath = uriPath.split("otpauth://")[1]; - } else if (uriPath.startsWith("otpauth%3A//")) { - uriPath = uriPath.split("otpauth%3A//")[1]; + try { + return _codeFromURIString(id, uriString); + } catch (e) { + // We might have legacy encodings of account names that contain a "#", + // which causes the rest of the URL to be treated as a fragment, and + // ignored. See if this was potentially such a case, otherwise rethrow. + if (uriString.includes("#")) + return _codeFromURIString(id, uriString.replaceAll("#", "%23")); + throw e; } +}; + +const _codeFromURIString = (id: string, uriString: string): Code => { + const url = new URL(uriString); + + // A URL like + // + // new URL("otpauth://hotp/Test?secret=AAABBBCCCDDDEEEFFF&issuer=Test&counter=0") + // + // is parsed differently by the browser and Node depending on the scheme. + // When the scheme is http(s), then both of them consider "hotp" as the + // `host`. However, when the scheme is "otpauth", as is our case, the + // browser considers the entire thing as part of the pathname. so we get. + // + // host: "" + // pathname: "//hotp/Test" + // + // Since this code run on browsers only, we parse as per that behaviour. + + const [type, path] = parsePathname(url); return { id, - type: _getType(uriPath), - account: _getAccount(uriPath), - issuer: _getIssuer(uriPath, uriParams), - digits: parseDigits(uriParams), - period: parsePeriod(uriParams), - secret: parseSecret(uriParams), - algorithm: parseAlgorithm(uriParams), + type, + account: parseAccount(path), + issuer: parseIssuer(url, path), + length: parseLength(url, type), + period: parsePeriod(url), + secret: parseSecret(url), + algorithm: parseAlgorithm(url), uriString, }; }; -const _getType = (uriPath: string): Code["type"] => { - const oauthType = uriPath.split("/")[0].substring(0); - if (oauthType.toLowerCase() === "totp") { - return "totp"; - } else if (oauthType.toLowerCase() === "hotp") { - return "hotp"; - } - throw new Error(`Unsupported format with host ${oauthType}`); +const parsePathname = (url: URL): [type: Code["type"], path: string] => { + const p = url.pathname.toLowerCase(); + if (p.startsWith("//totp")) return ["totp", url.pathname.slice(6)]; + if (p.startsWith("//hotp")) return ["hotp", url.pathname.slice(6)]; + if (p.startsWith("//steam")) return ["steam", url.pathname.slice(7)]; + throw new Error(`Unsupported code or unparseable path "${url.pathname}"`); }; -const _getAccount = (uriPath: string): string => { - try { - const path = decodeURIComponent(uriPath); - if (path.includes(":")) { - return path.split(":")[1]; - } else if (path.includes("/")) { - return path.split("/")[1]; - } - } catch (e) { - return ""; - } +const parseAccount = (path: string): string | undefined => { + // "/ACME:user@example.org" => "user@example.org" + let p = decodeURIComponent(path); + if (p.startsWith("/")) p = p.slice(1); + if (p.includes(":")) p = p.split(":").slice(1).join(":"); + return p; }; -const _getIssuer = (uriPath: string, uriParams: { get?: any }): string => { - try { - if (uriParams["issuer"] !== undefined) { - let issuer = uriParams["issuer"]; - // This is to handle bug in the ente auth app - if (issuer.endsWith("period")) { - issuer = issuer.substring(0, issuer.length - 6); - } - return issuer; +const parseIssuer = (url: URL, path: string): string => { + // If there is a "issuer" search param, use that. + let issuer = url.searchParams.get("issuer"); + if (issuer) { + // This is to handle bug in old versions of Ente Auth app. + if (issuer.endsWith("period")) { + issuer = issuer.substring(0, issuer.length - 6); } - let path = decodeURIComponent(uriPath); - if (path.startsWith("totp/") || path.startsWith("hotp/")) { - path = path.substring(5); - } - if (path.includes(":")) { - return path.split(":")[0]; - } else if (path.includes("-")) { - return path.split("-")[0]; - } - return path; - } catch (e) { - return ""; + return issuer; } + + // Otherwise use the `prefix:` from the account as the issuer. + // "/ACME:user@example.org" => "ACME" + let p = decodeURIComponent(path); + if (p.startsWith("/")) p = p.slice(1); + + if (p.includes(":")) p = p.split(":")[0]; + else if (p.includes("-")) p = p.split("-")[0]; + + return p; }; -const parseDigits = (uriParams): number => - parseInt(uriParams["digits"] ?? "", 10) || 6; +/** + * Parse the length of the generated code. + * + * The URI query param is called digits since originally TOTP/HOTP codes used + * this for generating numeric codes. Now we also support steam, which instead + * shows non-numeric codes, and also with a different default length of 5. + */ +const parseLength = (url: URL, type: Code["type"]): number => { + const defaultLength = type == "steam" ? 5 : 6; + return parseInt(url.searchParams.get("digits") ?? "", 10) || defaultLength; +}; -const parsePeriod = (uriParams): number => - parseInt(uriParams["period"] ?? "", 10) || 30; +const parsePeriod = (url: URL): number => + parseInt(url.searchParams.get("period") ?? "", 10) || 30; -const parseAlgorithm = (uriParams): Code["algorithm"] => { - switch (uriParams["algorithm"]?.toLowerCase()) { +const parseAlgorithm = (url: URL): Code["algorithm"] => { + switch (url.searchParams.get("algorithm")?.toLowerCase()) { case "sha256": return "sha256"; case "sha512": @@ -148,8 +164,8 @@ const parseAlgorithm = (uriParams): Code["algorithm"] => { } }; -const parseSecret = (uriParams): string => - uriParams["secret"].replaceAll(" ", "").toUpperCase(); +const parseSecret = (url: URL): string => + ensure(url.searchParams.get("secret")).replaceAll(" ", "").toUpperCase(); /** * Generate a pair of OTPs (one time passwords) from the given {@link code}. @@ -168,11 +184,11 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => { secret: code.secret, algorithm: code.algorithm, period: code.period, - digits: code.digits, + digits: code.length, }); otp = totp.generate(); nextOTP = totp.generate({ - timestamp: new Date().getTime() + code.period * 1000, + timestamp: Date.now() + code.period * 1000, }); break; } @@ -187,6 +203,17 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => { nextOTP = hotp.generate({ counter: 1 }); break; } + + case "steam": { + const steam = new Steam({ + secret: code.secret, + }); + otp = steam.generate(); + nextOTP = steam.generate({ + timestamp: Date.now() + code.period * 1000, + }); + break; + } } return [otp, nextOTP]; }; diff --git a/web/apps/auth/src/services/remote.ts b/web/apps/auth/src/services/remote.ts index 07b15d7d71..11d57aa23b 100644 --- a/web/apps/auth/src/services/remote.ts +++ b/web/apps/auth/src/services/remote.ts @@ -35,7 +35,7 @@ export const getAuthCodes = async (): Promise => { ); return codeFromURIString(entity.id, decryptedCode); } catch (e) { - log.error(`failed to parse codeId = ${entity.id}`); + log.error(`Failed to parse codeID ${entity.id}`, e); return null; } }), diff --git a/web/apps/auth/src/services/steam.ts b/web/apps/auth/src/services/steam.ts new file mode 100644 index 0000000000..f214640c24 --- /dev/null +++ b/web/apps/auth/src/services/steam.ts @@ -0,0 +1,74 @@ +import jsSHA from "jssha"; +import { Secret } from "otpauth"; + +/** + * Steam OTPs. + * + * Steam's algorithm is a custom variant of TOTP that uses a 26-character + * alphabet instead of digits. + * + * A Dart implementation of the algorithm can be found in + * https://github.com/elliotwutingfeng/steam_totp/blob/main/lib/src/steam_totp_base.dart + * (MIT license), and we use that as a reference. Our implementation is written + * in the style of the other TOTP/HOTP classes that are provided by the otpauth + * JS library that we use for the normal TOTP/HOTP generation + * https://github.com/hectorm/otpauth/blob/master/src/hotp.js (MIT license). + */ +export class Steam { + secret: Secret; + period: number; + + constructor({ secret }: { secret: string }) { + this.secret = Secret.fromBase32(secret); + this.period = 30; + } + + generate({ timestamp }: { timestamp: number } = { timestamp: Date.now() }) { + // Same as regular TOTP. + const counter = Math.floor(timestamp / 1000 / this.period); + + // Same as regular HOTP, but algorithm is fixed to SHA-1. + const digest = sha1HMACDigest(this.secret.buffer, uintToArray(counter)); + + // Same calculation as regular HOTP. + const offset = digest[digest.length - 1] & 15; + let otp = + ((digest[offset] & 127) << 24) | + ((digest[offset + 1] & 255) << 16) | + ((digest[offset + 2] & 255) << 8) | + (digest[offset + 3] & 255); + + // However, instead of using this as the OTP, use it to index into + // the steam OTP alphabet. + const alphabet = "23456789BCDFGHJKMNPQRTVWXY"; + const N = alphabet.length; + const steamOTP = []; + for (let i = 0; i < 5; i++) { + steamOTP.push(alphabet[otp % N]); + otp = Math.trunc(otp / N); + } + return steamOTP.join(""); + } +} + +// Equivalent to +// https://github.com/hectorm/otpauth/blob/master/src/utils/encoding/uint.js +const uintToArray = (n: number): Uint8Array => { + const result = new Uint8Array(8); + for (let i = 7; i >= 0; i--) { + result[i] = n & 255; + n >>= 8; + } + return result; +}; + +// We don't necessarily need a dependency on `jssha`, we could use SubtleCrypto +// here too. However, SubtleCrypto has an async interface, and we already have a +// transitive dependency on `jssha` via `otpauth`, so just using it here doesn't +// increase our bundle size any further. +const sha1HMACDigest = (key: ArrayBuffer, message: Uint8Array) => { + const hmac = new jsSHA("SHA-1", "UINT8ARRAY"); + hmac.setHMACKey(key, "ARRAYBUFFER"); + hmac.update(message); + return hmac.getHMAC("UINT8ARRAY"); +}; diff --git a/web/apps/photos/package.json b/web/apps/photos/package.json index 1541878c51..0ec924b29b 100644 --- a/web/apps/photos/package.json +++ b/web/apps/photos/package.json @@ -43,7 +43,6 @@ "similarity-transformation": "^0.0.1", "transformation-matrix": "^2.16", "uuid": "^9.0.1", - "vscode-uri": "^3.0.7", "xml-js": "^1.6.11", "zxcvbn": "^4.4.2" }, diff --git a/web/docs/dependencies.md b/web/docs/dependencies.md index 83a2b27990..3ea8fb2409 100644 --- a/web/docs/dependencies.md +++ b/web/docs/dependencies.md @@ -198,3 +198,7 @@ some cases. - [otpauth](https://github.com/hectorm/otpauth) is used for the generation of the actual OTP from the user's TOTP/HOTP secret. + +- However, otpauth doesn't support steam OTPs. For these, we need to compute + the SHA-1, and we use the same library, `jssha` that `otpauth` uses (since + it is already part of our bundle). diff --git a/web/yarn.lock b/web/yarn.lock index 894a44dd02..aaa0d517a8 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -4804,11 +4804,6 @@ void-elements@3.1.0: resolved "https://registry.yarnpkg.com/void-elements/-/void-elements-3.1.0.tgz#614f7fbf8d801f0bb5f0661f5b2f5785750e4f09" integrity sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w== -vscode-uri@^3.0.7: - version "3.0.8" - resolved "https://registry.yarnpkg.com/vscode-uri/-/vscode-uri-3.0.8.tgz#1770938d3e72588659a172d0fd4642780083ff9f" - integrity sha512-AyFQ0EVmsOZOlAnxoFOGOq1SQDWAB7C6aqMGS23svWAllfOaxbuFvcT8D1i8z3Gyn8fraVeZNNmN6e9bxxXkKw== - webidl-conversions@^3.0.0: version "3.0.1" resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"