Merge branch 'main' into migrate_files_db_to_sqlite_async
This commit is contained in:
@@ -13,6 +13,8 @@ import "package:photos/face/model/face.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
import 'package:sqlite_async/sqlite_async.dart';
|
||||
|
||||
/// Stores all data for the FacesML-related features. The database can be accessed by `FaceMLDataDB.instance.database`.
|
||||
@@ -249,7 +251,7 @@ class FaceMLDataDB {
|
||||
final List<int> fileId = [recentFileID];
|
||||
int? avatarFileId;
|
||||
if (avatarFaceId != null) {
|
||||
avatarFileId = int.tryParse(avatarFaceId.split('_')[0]);
|
||||
avatarFileId = tryGetFileIdFromFaceId(avatarFaceId);
|
||||
if (avatarFileId != null) {
|
||||
fileId.add(avatarFileId);
|
||||
}
|
||||
@@ -401,8 +403,10 @@ class FaceMLDataDB {
|
||||
final personID = map[personIdColumn] as String;
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceID = map[fcFaceId] as String;
|
||||
result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {})
|
||||
.add(faceID);
|
||||
result
|
||||
.putIfAbsent(personID, () => {})
|
||||
.putIfAbsent(clusterID, () => {})
|
||||
.add(faceID);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -476,8 +480,7 @@ class FaceMLDataDB {
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceID = map[fcFaceId] as String;
|
||||
final x = faceID.split('_').first;
|
||||
final fileID = int.parse(x);
|
||||
final fileID = getFileIdFromFaceId(faceID);
|
||||
result[fileID] = (result[fileID] ?? {})..add(clusterID);
|
||||
}
|
||||
return result;
|
||||
@@ -665,19 +668,38 @@ class FaceMLDataDB {
|
||||
return maps.first['count'] as int;
|
||||
}
|
||||
|
||||
Future<int> getClusteredFaceCount() async {
|
||||
Future<int> getClusteredFileCount() async {
|
||||
final db = await instance.asyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable',
|
||||
'SELECT $fcFaceId FROM $faceClustersTable',
|
||||
);
|
||||
return maps.first['count'] as int;
|
||||
final Set<int> fileIDs = {};
|
||||
for (final map in maps) {
|
||||
final int fileID = getFileIdFromFaceId(map[fcFaceId] as String);
|
||||
fileIDs.add(fileID);
|
||||
}
|
||||
return fileIDs.length;
|
||||
}
|
||||
|
||||
Future<double> getClusteredToTotalFacesRatio() async {
|
||||
final int totalFaces = await getTotalFaceCount();
|
||||
final int clusteredFaces = await getClusteredFaceCount();
|
||||
Future<double> getClusteredToIndexableFilesRatio() async {
|
||||
final int indexableFiles = (await getIndexableFileIDs()).length;
|
||||
final int clusteredFiles = await getClusteredFileCount();
|
||||
|
||||
return clusteredFaces / totalFaces;
|
||||
return clusteredFiles / indexableFiles;
|
||||
}
|
||||
|
||||
Future<int> getUnclusteredFaceCount() async {
|
||||
final db = await instance.asyncDB;
|
||||
const String query = '''
|
||||
SELECT f.$faceIDColumn
|
||||
FROM $facesTable f
|
||||
LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$fcFaceId
|
||||
WHERE f.$faceScore > $kMinimumQualityFaceScore
|
||||
AND f.$faceBlur > $kLaplacianHardThreshold
|
||||
AND fc.$fcFaceId IS NULL
|
||||
''';
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(query);
|
||||
return maps.length;
|
||||
}
|
||||
|
||||
Future<int> getBlurryFaceCount([
|
||||
@@ -795,7 +817,7 @@ class FaceMLDataDB {
|
||||
for (final map in maps) {
|
||||
final clusterID = map[clusterIDColumn] as int;
|
||||
final String faceID = map[fcFaceId] as String;
|
||||
final fileID = int.parse(faceID.split('_').first);
|
||||
final fileID = getFileIdFromFaceId(faceID);
|
||||
result[fileID] = (result[fileID] ?? {})..add(clusterID);
|
||||
}
|
||||
return result;
|
||||
@@ -814,8 +836,8 @@ class FaceMLDataDB {
|
||||
final Map<int, Set<int>> result = {};
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceId = map[fcFaceId] as String;
|
||||
final fileID = int.parse(faceId.split("_").first);
|
||||
final faceID = map[fcFaceId] as String;
|
||||
final fileID = getFileIdFromFaceId(faceID);
|
||||
result[fileID] = (result[fileID] ?? {})..add(clusterID);
|
||||
}
|
||||
return result;
|
||||
@@ -964,7 +986,7 @@ class FaceMLDataDB {
|
||||
final Map<String, int> faceIDToClusterID = {};
|
||||
for (final row in faceIdsResult) {
|
||||
final faceID = row[fcFaceId] as String;
|
||||
if (fileIds.contains(faceID.split('_').first)) {
|
||||
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
|
||||
maxClusterID += 1;
|
||||
faceIDToClusterID[faceID] = maxClusterID;
|
||||
}
|
||||
@@ -990,7 +1012,7 @@ class FaceMLDataDB {
|
||||
final Map<String, int> faceIDToClusterID = {};
|
||||
for (final row in faceIdsResult) {
|
||||
final faceID = row[fcFaceId] as String;
|
||||
if (fileIds.contains(faceID.split('_').first)) {
|
||||
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
|
||||
maxClusterID += 1;
|
||||
faceIDToClusterID[faceID] = maxClusterID;
|
||||
}
|
||||
|
||||
@@ -498,19 +498,8 @@ class FaceClusteringService {
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
|
||||
if (fileIDToCreationTime != null) {
|
||||
faceInfos.sort((a, b) {
|
||||
if (a.fileCreationTime == null && b.fileCreationTime == null) {
|
||||
return 0;
|
||||
} else if (a.fileCreationTime == null) {
|
||||
return 1;
|
||||
} else if (b.fileCreationTime == null) {
|
||||
return -1;
|
||||
} else {
|
||||
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
|
||||
}
|
||||
});
|
||||
_sortFaceInfosOnCreationTime(faceInfos);
|
||||
}
|
||||
|
||||
// Sort the faceInfos such that the ones with null clusterId are at the end
|
||||
@@ -796,19 +785,8 @@ class FaceClusteringService {
|
||||
);
|
||||
}
|
||||
|
||||
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
|
||||
if (fileIDToCreationTime != null) {
|
||||
faceInfos.sort((a, b) {
|
||||
if (a.fileCreationTime == null && b.fileCreationTime == null) {
|
||||
return 0;
|
||||
} else if (a.fileCreationTime == null) {
|
||||
return 1;
|
||||
} else if (b.fileCreationTime == null) {
|
||||
return -1;
|
||||
} else {
|
||||
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
|
||||
}
|
||||
});
|
||||
_sortFaceInfosOnCreationTime(faceInfos);
|
||||
}
|
||||
|
||||
if (faceInfos.isEmpty) {
|
||||
@@ -996,19 +974,8 @@ class FaceClusteringService {
|
||||
);
|
||||
}
|
||||
|
||||
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
|
||||
if (fileIDToCreationTime != null) {
|
||||
faceInfos.sort((a, b) {
|
||||
if (a.fileCreationTime == null && b.fileCreationTime == null) {
|
||||
return 0;
|
||||
} else if (a.fileCreationTime == null) {
|
||||
return 1;
|
||||
} else if (b.fileCreationTime == null) {
|
||||
return -1;
|
||||
} else {
|
||||
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
|
||||
}
|
||||
});
|
||||
_sortFaceInfosOnCreationTime(faceInfos);
|
||||
}
|
||||
|
||||
// Get the embeddings
|
||||
@@ -1027,3 +994,20 @@ class FaceClusteringService {
|
||||
return clusteredFaceIDs;
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first
|
||||
void _sortFaceInfosOnCreationTime(
|
||||
List<FaceInfo> faceInfos,
|
||||
) {
|
||||
faceInfos.sort((b, a) {
|
||||
if (a.fileCreationTime == null && b.fileCreationTime == null) {
|
||||
return 0;
|
||||
} else if (a.fileCreationTime == null) {
|
||||
return 1;
|
||||
} else if (b.fileCreationTime == null) {
|
||||
return -1;
|
||||
} else {
|
||||
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -310,5 +310,9 @@ class FaceResultBuilder {
|
||||
}
|
||||
|
||||
int getFileIdFromFaceId(String faceId) {
|
||||
return int.parse(faceId.split("_")[0]);
|
||||
return int.parse(faceId.split("_").first);
|
||||
}
|
||||
|
||||
int? tryGetFileIdFromFaceId(String faceId) {
|
||||
return int.tryParse(faceId.split("_").first);
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
|
||||
import "package:logging/logging.dart";
|
||||
import "package:onnxruntime/onnxruntime.dart";
|
||||
import "package:package_info_plus/package_info_plus.dart";
|
||||
import "package:photos/core/configuration.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
import "package:photos/events/diff_sync_complete_event.dart";
|
||||
@@ -99,6 +98,7 @@ class FaceMlService {
|
||||
|
||||
final int _fileDownloadLimit = 5;
|
||||
final int _embeddingFetchLimit = 200;
|
||||
final int _kForceClusteringFaceCount = 4000;
|
||||
|
||||
Future<void> init({bool initializeImageMlIsolate = false}) async {
|
||||
if (LocalSettings.instance.isFaceIndexingEnabled == false) {
|
||||
@@ -358,16 +358,17 @@ class FaceMlService {
|
||||
if (_cannotRunMLFunction()) return;
|
||||
|
||||
await sync(forceSync: _shouldSyncPeople);
|
||||
await indexAllImages();
|
||||
final indexingCompleteRatio = await _getIndexedDoneRatio();
|
||||
if (indexingCompleteRatio < 0.95) {
|
||||
|
||||
final int unclusteredFacesCount =
|
||||
await FaceMLDataDB.instance.getUnclusteredFaceCount();
|
||||
if (unclusteredFacesCount > _kForceClusteringFaceCount) {
|
||||
_logger.info(
|
||||
"Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio",
|
||||
"There are $unclusteredFacesCount unclustered faces, doing clustering first",
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
await clusterAllImages();
|
||||
}
|
||||
await indexAllImages();
|
||||
await clusterAllImages();
|
||||
}
|
||||
|
||||
void pauseIndexingAndClustering() {
|
||||
@@ -445,7 +446,7 @@ class FaceMlService {
|
||||
|
||||
if (LocalSettings.instance.remoteFetchEnabled) {
|
||||
try {
|
||||
final List<int> fileIds = [];
|
||||
final Set<int> fileIds = {}; // if there are duplicates here server returns 400
|
||||
// Try to find embeddings on the remote server
|
||||
for (final f in chunk) {
|
||||
fileIds.add(f.uploadedFileID!);
|
||||
@@ -590,8 +591,8 @@ class FaceMlService {
|
||||
allFaceInfoForClustering.add(faceInfo);
|
||||
}
|
||||
}
|
||||
// sort the embeddings based on file creation time, oldest first
|
||||
allFaceInfoForClustering.sort((a, b) {
|
||||
// sort the embeddings based on file creation time, newest first
|
||||
allFaceInfoForClustering.sort((b, a) {
|
||||
return fileIDToCreationTime[a.fileID]!
|
||||
.compareTo(fileIDToCreationTime[b.fileID]!);
|
||||
});
|
||||
@@ -1171,24 +1172,6 @@ class FaceMlService {
|
||||
}
|
||||
}
|
||||
|
||||
Future<double> _getIndexedDoneRatio() async {
|
||||
final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start();
|
||||
|
||||
final int alreadyIndexedCount = await FaceMLDataDB.instance
|
||||
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
|
||||
final int totalIndexableCount = (await getIndexableFileIDs()).length;
|
||||
final ratio = alreadyIndexedCount / totalIndexableCount;
|
||||
|
||||
w?.log('getIndexedDoneRatio');
|
||||
|
||||
return ratio;
|
||||
}
|
||||
|
||||
static Future<List<int>> getIndexableFileIDs() async {
|
||||
return FilesDB.instance
|
||||
.getOwnedFileIDs(Configuration.instance.getUserID()!);
|
||||
}
|
||||
|
||||
bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
|
||||
if (_isIndexingOrClusteringRunning == false ||
|
||||
_mlControllerStatus == false) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import "dart:async";
|
||||
import "dart:convert";
|
||||
|
||||
import "package:computer/computer.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/core/network/network.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
@@ -16,6 +17,8 @@ import "package:shared_preferences/shared_preferences.dart";
|
||||
class RemoteFileMLService {
|
||||
RemoteFileMLService._privateConstructor();
|
||||
|
||||
static final Computer _computer = Computer.shared();
|
||||
|
||||
static final RemoteFileMLService instance =
|
||||
RemoteFileMLService._privateConstructor();
|
||||
|
||||
@@ -52,13 +55,13 @@ class RemoteFileMLService {
|
||||
}
|
||||
|
||||
Future<FilesMLDataResponse> getFilessEmbedding(
|
||||
List<int> fileIds,
|
||||
Set<int> fileIds,
|
||||
) async {
|
||||
try {
|
||||
final res = await _dio.post(
|
||||
"/embeddings/files",
|
||||
data: {
|
||||
"fileIDs": fileIds,
|
||||
"fileIDs": fileIds.toList(),
|
||||
"model": 'file-ml-clip-face',
|
||||
},
|
||||
);
|
||||
@@ -107,15 +110,17 @@ class RemoteFileMLService {
|
||||
final input = EmbeddingsDecoderInput(embedding, fileKey);
|
||||
inputs.add(input);
|
||||
}
|
||||
// todo: use compute or isolate
|
||||
return decryptFileMLComputer(
|
||||
{
|
||||
return _computer.compute<Map<String, dynamic>, Map<int, FileMl>>(
|
||||
_decryptFileMLComputer,
|
||||
param: {
|
||||
"inputs": inputs,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Future<Map<int, FileMl>> decryptFileMLComputer(
|
||||
}
|
||||
|
||||
Future<Map<int, FileMl>> _decryptFileMLComputer(
|
||||
Map<String, dynamic> args,
|
||||
) async {
|
||||
final result = <int, FileMl>{};
|
||||
@@ -134,5 +139,4 @@ class RemoteFileMLService {
|
||||
result[input.embedding.fileID] = decodedEmbedding;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx
|
||||
import "package:photos/utils/debouncer.dart";
|
||||
import "package:photos/utils/device_info.dart";
|
||||
import "package:photos/utils/local_settings.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
import "package:photos/utils/thumbnail_util.dart";
|
||||
|
||||
class SemanticSearchService {
|
||||
@@ -160,8 +161,7 @@ class SemanticSearchService {
|
||||
}
|
||||
|
||||
Future<IndexStatus> getIndexStatus() async {
|
||||
final indexableFileIDs = await FilesDB.instance
|
||||
.getOwnedFileIDs(Configuration.instance.getUserID()!);
|
||||
final indexableFileIDs = await getIndexableFileIDs();
|
||||
return IndexStatus(
|
||||
min(_cachedEmbeddings.length, indexableFileIDs.length),
|
||||
(await _getFileIDsToBeIndexed()).length,
|
||||
@@ -222,8 +222,7 @@ class SemanticSearchService {
|
||||
}
|
||||
|
||||
Future<List<int>> _getFileIDsToBeIndexed() async {
|
||||
final uploadedFileIDs = await FilesDB.instance
|
||||
.getOwnedFileIDs(Configuration.instance.getUserID()!);
|
||||
final uploadedFileIDs = await getIndexableFileIDs();
|
||||
final embeddedFileIDs =
|
||||
await EmbeddingsDB.instance.getFileIDs(_currentModel);
|
||||
|
||||
|
||||
@@ -754,15 +754,6 @@ class SearchService {
|
||||
|
||||
Future<List<GenericSearchResult>> getAllFace(int? limit) async {
|
||||
try {
|
||||
// Don't return anything if clustering is not nearly complete yet
|
||||
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
|
||||
final clusteredFaces =
|
||||
await FaceMLDataDB.instance.getClusteredFaceCount();
|
||||
final clusteringDoneRatio = clusteredFaces / foundFaces;
|
||||
if (clusteringDoneRatio < 0.9) {
|
||||
return [];
|
||||
}
|
||||
|
||||
debugPrint("getting faces");
|
||||
final Map<int, Set<int>> fileIdToClusterID =
|
||||
await FaceMLDataDB.instance.getFileIdToClusterIds();
|
||||
|
||||
@@ -177,7 +177,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
|
||||
sectionOptionSpacing,
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: FutureBuilder<double>(
|
||||
future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(),
|
||||
future: FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(),
|
||||
builder: (context, snapshot) {
|
||||
if (snapshot.hasData) {
|
||||
return CaptionedTextWidget(
|
||||
|
||||
@@ -26,6 +26,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
|
||||
import "package:photos/ui/components/toggle_switch_widget.dart";
|
||||
import "package:photos/utils/data_util.dart";
|
||||
import "package:photos/utils/local_settings.dart";
|
||||
import "package:photos/utils/ml_util.dart";
|
||||
|
||||
final _logger = Logger("MachineLearningSettingsPage");
|
||||
|
||||
@@ -438,19 +439,16 @@ class FaceRecognitionStatusWidgetState
|
||||
});
|
||||
}
|
||||
|
||||
Future<(int, int, int, double)> getIndexStatus() async {
|
||||
Future<(int, int, double)> getIndexStatus() async {
|
||||
try {
|
||||
final indexedFiles = await FaceMLDataDB.instance
|
||||
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
|
||||
final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length;
|
||||
final indexableFiles = (await getIndexableFileIDs()).length;
|
||||
final showIndexedFiles = min(indexedFiles, indexableFiles);
|
||||
final pendingFiles = max(indexableFiles - indexedFiles, 0);
|
||||
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
|
||||
final clusteredFaces =
|
||||
await FaceMLDataDB.instance.getClusteredFaceCount();
|
||||
final clusteringDoneRatio = clusteredFaces / foundFaces;
|
||||
final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
|
||||
|
||||
return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
|
||||
return (showIndexedFiles, pendingFiles, clusteringDoneRatio);
|
||||
} catch (e, s) {
|
||||
_logger.severe('Error getting face recognition status', e, s);
|
||||
rethrow;
|
||||
@@ -479,8 +477,7 @@ class FaceRecognitionStatusWidgetState
|
||||
if (snapshot.hasData) {
|
||||
final int indexedFiles = snapshot.data!.$1;
|
||||
final int pendingFiles = snapshot.data!.$2;
|
||||
final int foundFaces = snapshot.data!.$3;
|
||||
final double clusteringDoneRatio = snapshot.data!.$4;
|
||||
final double clusteringDoneRatio = snapshot.data!.$3;
|
||||
final double clusteringPercentage =
|
||||
(clusteringDoneRatio * 100).clamp(0, 100);
|
||||
|
||||
@@ -512,19 +509,6 @@ class FaceRecognitionStatusWidgetState
|
||||
isGestureDetectorDisabled: true,
|
||||
key: ValueKey("pending_items_" + pendingFiles.toString()),
|
||||
),
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: S.of(context).foundFaces,
|
||||
),
|
||||
trailingWidget: Text(
|
||||
NumberFormat().format(foundFaces),
|
||||
style: Theme.of(context).textTheme.bodySmall,
|
||||
),
|
||||
singleBorderRadius: 8,
|
||||
alignCaptionedTextToLeft: true,
|
||||
isGestureDetectorDisabled: true,
|
||||
key: ValueKey("found_faces_" + foundFaces.toString()),
|
||||
),
|
||||
MenuItemWidget(
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: S.of(context).clusteringProgress,
|
||||
|
||||
7
mobile/lib/utils/ml_util.dart
Normal file
7
mobile/lib/utils/ml_util.dart
Normal file
@@ -0,0 +1,7 @@
|
||||
import "package:photos/core/configuration.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
|
||||
Future<List<int>> getIndexableFileIDs() async {
|
||||
return FilesDB.instance
|
||||
.getOwnedFileIDs(Configuration.instance.getUserID()!);
|
||||
}
|
||||
@@ -12,7 +12,7 @@ description: ente photos application
|
||||
# Read more about iOS versioning at
|
||||
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
|
||||
|
||||
version: 0.8.109+633
|
||||
version: 0.8.110+634
|
||||
publish_to: none
|
||||
|
||||
environment:
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"@ente/accounts": "*",
|
||||
"@ente/eslint-config": "*",
|
||||
"@ente/shared": "*",
|
||||
"jssha": "~3.3.1",
|
||||
"otpauth": "^9"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,14 +46,11 @@ const AuthenticatorCodesPage = () => {
|
||||
appContext.showNavBar(false);
|
||||
}, []);
|
||||
|
||||
const lcSearch = searchTerm.toLowerCase();
|
||||
const filteredCodes = codes.filter(
|
||||
(secret) =>
|
||||
(secret.issuer ?? "")
|
||||
.toLowerCase()
|
||||
.includes(searchTerm.toLowerCase()) ||
|
||||
(secret.account ?? "")
|
||||
.toLowerCase()
|
||||
.includes(searchTerm.toLowerCase()),
|
||||
(code) =>
|
||||
code.issuer?.toLowerCase().includes(lcSearch) ||
|
||||
code.account?.toLowerCase().includes(lcSearch),
|
||||
);
|
||||
|
||||
if (!hasFetched) {
|
||||
@@ -190,28 +187,21 @@ const CodeDisplay: React.FC<CodeDisplay> = ({ code }) => {
|
||||
useEffect(() => {
|
||||
// Generate to set the initial otp and nextOTP on component mount.
|
||||
regen();
|
||||
const codeType = code.type;
|
||||
const codePeriodInMs = code.period * 1000;
|
||||
const timeToNextCode =
|
||||
codePeriodInMs - (new Date().getTime() % codePeriodInMs);
|
||||
const interval = null;
|
||||
|
||||
const periodMs = code.period * 1000;
|
||||
const timeToNextCode = periodMs - (Date.now() % periodMs);
|
||||
|
||||
let interval: ReturnType<typeof setInterval> | undefined;
|
||||
// Wait until we are at the start of the next code period, and then
|
||||
// start the interval loop.
|
||||
setTimeout(() => {
|
||||
// We need to call regen() once before the interval loop to set the
|
||||
// initial otp and nextOTP.
|
||||
regen();
|
||||
codeType.toLowerCase() === "totp" ||
|
||||
codeType.toLowerCase() === "hotp"
|
||||
? setInterval(() => {
|
||||
regen();
|
||||
}, codePeriodInMs)
|
||||
: null;
|
||||
interval = setInterval(regen, periodMs);
|
||||
}, timeToNextCode);
|
||||
|
||||
return () => {
|
||||
if (interval) clearInterval(interval);
|
||||
};
|
||||
return () => interval && clearInterval(interval);
|
||||
}, [code]);
|
||||
|
||||
return (
|
||||
@@ -270,7 +260,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
|
||||
textAlign: "left",
|
||||
}}
|
||||
>
|
||||
{code.issuer}
|
||||
{code.issuer ?? ""}
|
||||
</p>
|
||||
<p
|
||||
style={{
|
||||
@@ -283,7 +273,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
|
||||
color: "grey",
|
||||
}}
|
||||
>
|
||||
{code.account}
|
||||
{code.account ?? ""}
|
||||
</p>
|
||||
<p
|
||||
style={{
|
||||
@@ -349,7 +339,7 @@ const TimerProgress: React.FC<TimerProgressProps> = ({ period }) => {
|
||||
|
||||
useEffect(() => {
|
||||
const advance = () => {
|
||||
const timeRemaining = us - ((new Date().getTime() * 1000) % us);
|
||||
const timeRemaining = us - ((Date.now() * 1000) % us);
|
||||
setProgress(timeRemaining / us);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { ensure } from "@/utils/ensure";
|
||||
import { HOTP, TOTP } from "otpauth";
|
||||
import { URI } from "vscode-uri";
|
||||
import { Steam } from "./steam";
|
||||
|
||||
/**
|
||||
* A parsed representation of an *OTP code URI.
|
||||
@@ -10,13 +11,19 @@ export interface Code {
|
||||
/** A unique id for the corresponding "auth entity" in our system. */
|
||||
id?: String;
|
||||
/** The type of the code. */
|
||||
type: "totp" | "hotp";
|
||||
type: "totp" | "hotp" | "steam";
|
||||
/** The user's account or email for which this code is used. */
|
||||
account: string;
|
||||
account?: string;
|
||||
/** The name of the entity that issued this code. */
|
||||
issuer: string;
|
||||
/** Number of digits in the generated OTP. */
|
||||
digits: number;
|
||||
/**
|
||||
* Length of the generated OTP.
|
||||
*
|
||||
* This is vernacularly called "digits", which is an accurate description
|
||||
* for the OG TOTP/HOTP codes. However, steam codes are not just digits, so
|
||||
* we name this as a content-neutral "length".
|
||||
*/
|
||||
length: number;
|
||||
/**
|
||||
* The time period (in seconds) for which a single OTP generated from this
|
||||
* code remains valid.
|
||||
@@ -32,7 +39,7 @@ export interface Code {
|
||||
/** The (HMAC) algorithm used by the OTP generator. */
|
||||
algorithm: "sha1" | "sha256" | "sha512";
|
||||
/** The original string from which this code was generated. */
|
||||
uriString?: string;
|
||||
uriString: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -45,100 +52,109 @@ export interface Code {
|
||||
*
|
||||
* - (TOTP)
|
||||
* otpauth://totp/ACME:user@example.org?algorithm=SHA1&digits=6&issuer=acme&period=30&secret=ALPHANUM
|
||||
*
|
||||
* See also `auth/test/models/code_test.dart`.
|
||||
*/
|
||||
export const codeFromURIString = (id: string, uriString: string): Code => {
|
||||
const santizedRawData = uriString
|
||||
.replaceAll("+", "%2B")
|
||||
.replaceAll(":", "%3A")
|
||||
.replaceAll("\r", "")
|
||||
// trim quotes
|
||||
.replace(/^"|"$/g, "");
|
||||
|
||||
const uriParams = {};
|
||||
const searchParamsString =
|
||||
decodeURIComponent(santizedRawData).split("?")[1];
|
||||
searchParamsString.split("&").forEach((pair) => {
|
||||
const [key, value] = pair.split("=");
|
||||
uriParams[key] = value;
|
||||
});
|
||||
|
||||
const uri = URI.parse(santizedRawData);
|
||||
let uriPath = decodeURIComponent(uri.path);
|
||||
if (uriPath.startsWith("/otpauth://") || uriPath.startsWith("otpauth://")) {
|
||||
uriPath = uriPath.split("otpauth://")[1];
|
||||
} else if (uriPath.startsWith("otpauth%3A//")) {
|
||||
uriPath = uriPath.split("otpauth%3A//")[1];
|
||||
try {
|
||||
return _codeFromURIString(id, uriString);
|
||||
} catch (e) {
|
||||
// We might have legacy encodings of account names that contain a "#",
|
||||
// which causes the rest of the URL to be treated as a fragment, and
|
||||
// ignored. See if this was potentially such a case, otherwise rethrow.
|
||||
if (uriString.includes("#"))
|
||||
return _codeFromURIString(id, uriString.replaceAll("#", "%23"));
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
const _codeFromURIString = (id: string, uriString: string): Code => {
|
||||
const url = new URL(uriString);
|
||||
|
||||
// A URL like
|
||||
//
|
||||
// new URL("otpauth://hotp/Test?secret=AAABBBCCCDDDEEEFFF&issuer=Test&counter=0")
|
||||
//
|
||||
// is parsed differently by the browser and Node depending on the scheme.
|
||||
// When the scheme is http(s), then both of them consider "hotp" as the
|
||||
// `host`. However, when the scheme is "otpauth", as is our case, the
|
||||
// browser considers the entire thing as part of the pathname. so we get.
|
||||
//
|
||||
// host: ""
|
||||
// pathname: "//hotp/Test"
|
||||
//
|
||||
// Since this code run on browsers only, we parse as per that behaviour.
|
||||
|
||||
const [type, path] = parsePathname(url);
|
||||
|
||||
return {
|
||||
id,
|
||||
type: _getType(uriPath),
|
||||
account: _getAccount(uriPath),
|
||||
issuer: _getIssuer(uriPath, uriParams),
|
||||
digits: parseDigits(uriParams),
|
||||
period: parsePeriod(uriParams),
|
||||
secret: parseSecret(uriParams),
|
||||
algorithm: parseAlgorithm(uriParams),
|
||||
type,
|
||||
account: parseAccount(path),
|
||||
issuer: parseIssuer(url, path),
|
||||
length: parseLength(url, type),
|
||||
period: parsePeriod(url),
|
||||
secret: parseSecret(url),
|
||||
algorithm: parseAlgorithm(url),
|
||||
uriString,
|
||||
};
|
||||
};
|
||||
|
||||
const _getType = (uriPath: string): Code["type"] => {
|
||||
const oauthType = uriPath.split("/")[0].substring(0);
|
||||
if (oauthType.toLowerCase() === "totp") {
|
||||
return "totp";
|
||||
} else if (oauthType.toLowerCase() === "hotp") {
|
||||
return "hotp";
|
||||
}
|
||||
throw new Error(`Unsupported format with host ${oauthType}`);
|
||||
const parsePathname = (url: URL): [type: Code["type"], path: string] => {
|
||||
const p = url.pathname.toLowerCase();
|
||||
if (p.startsWith("//totp")) return ["totp", url.pathname.slice(6)];
|
||||
if (p.startsWith("//hotp")) return ["hotp", url.pathname.slice(6)];
|
||||
if (p.startsWith("//steam")) return ["steam", url.pathname.slice(7)];
|
||||
throw new Error(`Unsupported code or unparseable path "${url.pathname}"`);
|
||||
};
|
||||
|
||||
const _getAccount = (uriPath: string): string => {
|
||||
try {
|
||||
const path = decodeURIComponent(uriPath);
|
||||
if (path.includes(":")) {
|
||||
return path.split(":")[1];
|
||||
} else if (path.includes("/")) {
|
||||
return path.split("/")[1];
|
||||
}
|
||||
} catch (e) {
|
||||
return "";
|
||||
}
|
||||
const parseAccount = (path: string): string | undefined => {
|
||||
// "/ACME:user@example.org" => "user@example.org"
|
||||
let p = decodeURIComponent(path);
|
||||
if (p.startsWith("/")) p = p.slice(1);
|
||||
if (p.includes(":")) p = p.split(":").slice(1).join(":");
|
||||
return p;
|
||||
};
|
||||
|
||||
const _getIssuer = (uriPath: string, uriParams: { get?: any }): string => {
|
||||
try {
|
||||
if (uriParams["issuer"] !== undefined) {
|
||||
let issuer = uriParams["issuer"];
|
||||
// This is to handle bug in the ente auth app
|
||||
if (issuer.endsWith("period")) {
|
||||
issuer = issuer.substring(0, issuer.length - 6);
|
||||
}
|
||||
return issuer;
|
||||
const parseIssuer = (url: URL, path: string): string => {
|
||||
// If there is a "issuer" search param, use that.
|
||||
let issuer = url.searchParams.get("issuer");
|
||||
if (issuer) {
|
||||
// This is to handle bug in old versions of Ente Auth app.
|
||||
if (issuer.endsWith("period")) {
|
||||
issuer = issuer.substring(0, issuer.length - 6);
|
||||
}
|
||||
let path = decodeURIComponent(uriPath);
|
||||
if (path.startsWith("totp/") || path.startsWith("hotp/")) {
|
||||
path = path.substring(5);
|
||||
}
|
||||
if (path.includes(":")) {
|
||||
return path.split(":")[0];
|
||||
} else if (path.includes("-")) {
|
||||
return path.split("-")[0];
|
||||
}
|
||||
return path;
|
||||
} catch (e) {
|
||||
return "";
|
||||
return issuer;
|
||||
}
|
||||
|
||||
// Otherwise use the `prefix:` from the account as the issuer.
|
||||
// "/ACME:user@example.org" => "ACME"
|
||||
let p = decodeURIComponent(path);
|
||||
if (p.startsWith("/")) p = p.slice(1);
|
||||
|
||||
if (p.includes(":")) p = p.split(":")[0];
|
||||
else if (p.includes("-")) p = p.split("-")[0];
|
||||
|
||||
return p;
|
||||
};
|
||||
|
||||
const parseDigits = (uriParams): number =>
|
||||
parseInt(uriParams["digits"] ?? "", 10) || 6;
|
||||
/**
|
||||
* Parse the length of the generated code.
|
||||
*
|
||||
* The URI query param is called digits since originally TOTP/HOTP codes used
|
||||
* this for generating numeric codes. Now we also support steam, which instead
|
||||
* shows non-numeric codes, and also with a different default length of 5.
|
||||
*/
|
||||
const parseLength = (url: URL, type: Code["type"]): number => {
|
||||
const defaultLength = type == "steam" ? 5 : 6;
|
||||
return parseInt(url.searchParams.get("digits") ?? "", 10) || defaultLength;
|
||||
};
|
||||
|
||||
const parsePeriod = (uriParams): number =>
|
||||
parseInt(uriParams["period"] ?? "", 10) || 30;
|
||||
const parsePeriod = (url: URL): number =>
|
||||
parseInt(url.searchParams.get("period") ?? "", 10) || 30;
|
||||
|
||||
const parseAlgorithm = (uriParams): Code["algorithm"] => {
|
||||
switch (uriParams["algorithm"]?.toLowerCase()) {
|
||||
const parseAlgorithm = (url: URL): Code["algorithm"] => {
|
||||
switch (url.searchParams.get("algorithm")?.toLowerCase()) {
|
||||
case "sha256":
|
||||
return "sha256";
|
||||
case "sha512":
|
||||
@@ -148,8 +164,8 @@ const parseAlgorithm = (uriParams): Code["algorithm"] => {
|
||||
}
|
||||
};
|
||||
|
||||
const parseSecret = (uriParams): string =>
|
||||
uriParams["secret"].replaceAll(" ", "").toUpperCase();
|
||||
const parseSecret = (url: URL): string =>
|
||||
ensure(url.searchParams.get("secret")).replaceAll(" ", "").toUpperCase();
|
||||
|
||||
/**
|
||||
* Generate a pair of OTPs (one time passwords) from the given {@link code}.
|
||||
@@ -168,11 +184,11 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
|
||||
secret: code.secret,
|
||||
algorithm: code.algorithm,
|
||||
period: code.period,
|
||||
digits: code.digits,
|
||||
digits: code.length,
|
||||
});
|
||||
otp = totp.generate();
|
||||
nextOTP = totp.generate({
|
||||
timestamp: new Date().getTime() + code.period * 1000,
|
||||
timestamp: Date.now() + code.period * 1000,
|
||||
});
|
||||
break;
|
||||
}
|
||||
@@ -187,6 +203,17 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
|
||||
nextOTP = hotp.generate({ counter: 1 });
|
||||
break;
|
||||
}
|
||||
|
||||
case "steam": {
|
||||
const steam = new Steam({
|
||||
secret: code.secret,
|
||||
});
|
||||
otp = steam.generate();
|
||||
nextOTP = steam.generate({
|
||||
timestamp: Date.now() + code.period * 1000,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
return [otp, nextOTP];
|
||||
};
|
||||
|
||||
@@ -35,7 +35,7 @@ export const getAuthCodes = async (): Promise<Code[]> => {
|
||||
);
|
||||
return codeFromURIString(entity.id, decryptedCode);
|
||||
} catch (e) {
|
||||
log.error(`failed to parse codeId = ${entity.id}`);
|
||||
log.error(`Failed to parse codeID ${entity.id}`, e);
|
||||
return null;
|
||||
}
|
||||
}),
|
||||
|
||||
74
web/apps/auth/src/services/steam.ts
Normal file
74
web/apps/auth/src/services/steam.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import jsSHA from "jssha";
|
||||
import { Secret } from "otpauth";
|
||||
|
||||
/**
|
||||
* Steam OTPs.
|
||||
*
|
||||
* Steam's algorithm is a custom variant of TOTP that uses a 26-character
|
||||
* alphabet instead of digits.
|
||||
*
|
||||
* A Dart implementation of the algorithm can be found in
|
||||
* https://github.com/elliotwutingfeng/steam_totp/blob/main/lib/src/steam_totp_base.dart
|
||||
* (MIT license), and we use that as a reference. Our implementation is written
|
||||
* in the style of the other TOTP/HOTP classes that are provided by the otpauth
|
||||
* JS library that we use for the normal TOTP/HOTP generation
|
||||
* https://github.com/hectorm/otpauth/blob/master/src/hotp.js (MIT license).
|
||||
*/
|
||||
export class Steam {
|
||||
secret: Secret;
|
||||
period: number;
|
||||
|
||||
constructor({ secret }: { secret: string }) {
|
||||
this.secret = Secret.fromBase32(secret);
|
||||
this.period = 30;
|
||||
}
|
||||
|
||||
generate({ timestamp }: { timestamp: number } = { timestamp: Date.now() }) {
|
||||
// Same as regular TOTP.
|
||||
const counter = Math.floor(timestamp / 1000 / this.period);
|
||||
|
||||
// Same as regular HOTP, but algorithm is fixed to SHA-1.
|
||||
const digest = sha1HMACDigest(this.secret.buffer, uintToArray(counter));
|
||||
|
||||
// Same calculation as regular HOTP.
|
||||
const offset = digest[digest.length - 1] & 15;
|
||||
let otp =
|
||||
((digest[offset] & 127) << 24) |
|
||||
((digest[offset + 1] & 255) << 16) |
|
||||
((digest[offset + 2] & 255) << 8) |
|
||||
(digest[offset + 3] & 255);
|
||||
|
||||
// However, instead of using this as the OTP, use it to index into
|
||||
// the steam OTP alphabet.
|
||||
const alphabet = "23456789BCDFGHJKMNPQRTVWXY";
|
||||
const N = alphabet.length;
|
||||
const steamOTP = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
steamOTP.push(alphabet[otp % N]);
|
||||
otp = Math.trunc(otp / N);
|
||||
}
|
||||
return steamOTP.join("");
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent to
|
||||
// https://github.com/hectorm/otpauth/blob/master/src/utils/encoding/uint.js
|
||||
const uintToArray = (n: number): Uint8Array => {
|
||||
const result = new Uint8Array(8);
|
||||
for (let i = 7; i >= 0; i--) {
|
||||
result[i] = n & 255;
|
||||
n >>= 8;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
// We don't necessarily need a dependency on `jssha`, we could use SubtleCrypto
|
||||
// here too. However, SubtleCrypto has an async interface, and we already have a
|
||||
// transitive dependency on `jssha` via `otpauth`, so just using it here doesn't
|
||||
// increase our bundle size any further.
|
||||
const sha1HMACDigest = (key: ArrayBuffer, message: Uint8Array) => {
|
||||
const hmac = new jsSHA("SHA-1", "UINT8ARRAY");
|
||||
hmac.setHMACKey(key, "ARRAYBUFFER");
|
||||
hmac.update(message);
|
||||
return hmac.getHMAC("UINT8ARRAY");
|
||||
};
|
||||
@@ -43,7 +43,6 @@
|
||||
"similarity-transformation": "^0.0.1",
|
||||
"transformation-matrix": "^2.16",
|
||||
"uuid": "^9.0.1",
|
||||
"vscode-uri": "^3.0.7",
|
||||
"xml-js": "^1.6.11",
|
||||
"zxcvbn": "^4.4.2"
|
||||
},
|
||||
|
||||
@@ -198,3 +198,7 @@ some cases.
|
||||
|
||||
- [otpauth](https://github.com/hectorm/otpauth) is used for the generation of
|
||||
the actual OTP from the user's TOTP/HOTP secret.
|
||||
|
||||
- However, otpauth doesn't support steam OTPs. For these, we need to compute
|
||||
the SHA-1, and we use the same library, `jssha` that `otpauth` uses (since
|
||||
it is already part of our bundle).
|
||||
|
||||
@@ -4804,11 +4804,6 @@ void-elements@3.1.0:
|
||||
resolved "https://registry.yarnpkg.com/void-elements/-/void-elements-3.1.0.tgz#614f7fbf8d801f0bb5f0661f5b2f5785750e4f09"
|
||||
integrity sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==
|
||||
|
||||
vscode-uri@^3.0.7:
|
||||
version "3.0.8"
|
||||
resolved "https://registry.yarnpkg.com/vscode-uri/-/vscode-uri-3.0.8.tgz#1770938d3e72588659a172d0fd4642780083ff9f"
|
||||
integrity sha512-AyFQ0EVmsOZOlAnxoFOGOq1SQDWAB7C6aqMGS23svWAllfOaxbuFvcT8D1i8z3Gyn8fraVeZNNmN6e9bxxXkKw==
|
||||
|
||||
webidl-conversions@^3.0.0:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
|
||||
|
||||
Reference in New Issue
Block a user