Merge branch 'main' into migrate_files_db_to_sqlite_async

This commit is contained in:
ashilkn
2024-05-24 14:39:46 +05:30
19 changed files with 310 additions and 242 deletions

View File

@@ -13,6 +13,8 @@ import "package:photos/face/model/face.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/utils/ml_util.dart";
import 'package:sqlite_async/sqlite_async.dart';
/// Stores all data for the FacesML-related features. The database can be accessed by `FaceMLDataDB.instance.database`.
@@ -249,7 +251,7 @@ class FaceMLDataDB {
final List<int> fileId = [recentFileID];
int? avatarFileId;
if (avatarFaceId != null) {
avatarFileId = int.tryParse(avatarFaceId.split('_')[0]);
avatarFileId = tryGetFileIdFromFaceId(avatarFaceId);
if (avatarFileId != null) {
fileId.add(avatarFileId);
}
@@ -401,8 +403,10 @@ class FaceMLDataDB {
final personID = map[personIdColumn] as String;
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {})
.add(faceID);
result
.putIfAbsent(personID, () => {})
.putIfAbsent(clusterID, () => {})
.add(faceID);
}
return result;
}
@@ -476,8 +480,7 @@ class FaceMLDataDB {
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
final x = faceID.split('_').first;
final fileID = int.parse(x);
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@@ -665,19 +668,38 @@ class FaceMLDataDB {
return maps.first['count'] as int;
}
Future<int> getClusteredFaceCount() async {
Future<int> getClusteredFileCount() async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable',
'SELECT $fcFaceId FROM $faceClustersTable',
);
return maps.first['count'] as int;
final Set<int> fileIDs = {};
for (final map in maps) {
final int fileID = getFileIdFromFaceId(map[fcFaceId] as String);
fileIDs.add(fileID);
}
return fileIDs.length;
}
Future<double> getClusteredToTotalFacesRatio() async {
final int totalFaces = await getTotalFaceCount();
final int clusteredFaces = await getClusteredFaceCount();
Future<double> getClusteredToIndexableFilesRatio() async {
final int indexableFiles = (await getIndexableFileIDs()).length;
final int clusteredFiles = await getClusteredFileCount();
return clusteredFaces / totalFaces;
return clusteredFiles / indexableFiles;
}
Future<int> getUnclusteredFaceCount() async {
final db = await instance.asyncDB;
const String query = '''
SELECT f.$faceIDColumn
FROM $facesTable f
LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$fcFaceId
WHERE f.$faceScore > $kMinimumQualityFaceScore
AND f.$faceBlur > $kLaplacianHardThreshold
AND fc.$fcFaceId IS NULL
''';
final List<Map<String, dynamic>> maps = await db.getAll(query);
return maps.length;
}
Future<int> getBlurryFaceCount([
@@ -795,7 +817,7 @@ class FaceMLDataDB {
for (final map in maps) {
final clusterID = map[clusterIDColumn] as int;
final String faceID = map[fcFaceId] as String;
final fileID = int.parse(faceID.split('_').first);
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@@ -814,8 +836,8 @@ class FaceMLDataDB {
final Map<int, Set<int>> result = {};
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceId = map[fcFaceId] as String;
final fileID = int.parse(faceId.split("_").first);
final faceID = map[fcFaceId] as String;
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@@ -964,7 +986,7 @@ class FaceMLDataDB {
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
if (fileIds.contains(faceID.split('_').first)) {
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
maxClusterID += 1;
faceIDToClusterID[faceID] = maxClusterID;
}
@@ -990,7 +1012,7 @@ class FaceMLDataDB {
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
if (fileIds.contains(faceID.split('_').first)) {
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
maxClusterID += 1;
faceIDToClusterID[faceID] = maxClusterID;
}

View File

@@ -498,19 +498,8 @@ class FaceClusteringService {
}
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
// Sort the faceInfos such that the ones with null clusterId are at the end
@@ -796,19 +785,8 @@ class FaceClusteringService {
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
if (faceInfos.isEmpty) {
@@ -996,19 +974,8 @@ class FaceClusteringService {
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
// Get the embeddings
@@ -1027,3 +994,20 @@ class FaceClusteringService {
return clusteredFaceIDs;
}
}
/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first
void _sortFaceInfosOnCreationTime(
List<FaceInfo> faceInfos,
) {
faceInfos.sort((b, a) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
}

View File

@@ -310,5 +310,9 @@ class FaceResultBuilder {
}
int getFileIdFromFaceId(String faceId) {
return int.parse(faceId.split("_")[0]);
return int.parse(faceId.split("_").first);
}
int? tryGetFileIdFromFaceId(String faceId) {
return int.tryParse(faceId.split("_").first);
}

View File

@@ -12,7 +12,6 @@ import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:package_info_plus/package_info_plus.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
@@ -99,6 +98,7 @@ class FaceMlService {
final int _fileDownloadLimit = 5;
final int _embeddingFetchLimit = 200;
final int _kForceClusteringFaceCount = 4000;
Future<void> init({bool initializeImageMlIsolate = false}) async {
if (LocalSettings.instance.isFaceIndexingEnabled == false) {
@@ -358,16 +358,17 @@ class FaceMlService {
if (_cannotRunMLFunction()) return;
await sync(forceSync: _shouldSyncPeople);
await indexAllImages();
final indexingCompleteRatio = await _getIndexedDoneRatio();
if (indexingCompleteRatio < 0.95) {
final int unclusteredFacesCount =
await FaceMLDataDB.instance.getUnclusteredFaceCount();
if (unclusteredFacesCount > _kForceClusteringFaceCount) {
_logger.info(
"Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio",
"There are $unclusteredFacesCount unclustered faces, doing clustering first",
);
return;
} else {
await clusterAllImages();
}
await indexAllImages();
await clusterAllImages();
}
void pauseIndexingAndClustering() {
@@ -445,7 +446,7 @@ class FaceMlService {
if (LocalSettings.instance.remoteFetchEnabled) {
try {
final List<int> fileIds = [];
final Set<int> fileIds = {}; // if there are duplicates here server returns 400
// Try to find embeddings on the remote server
for (final f in chunk) {
fileIds.add(f.uploadedFileID!);
@@ -590,8 +591,8 @@ class FaceMlService {
allFaceInfoForClustering.add(faceInfo);
}
}
// sort the embeddings based on file creation time, oldest first
allFaceInfoForClustering.sort((a, b) {
// sort the embeddings based on file creation time, newest first
allFaceInfoForClustering.sort((b, a) {
return fileIDToCreationTime[a.fileID]!
.compareTo(fileIDToCreationTime[b.fileID]!);
});
@@ -1171,24 +1172,6 @@ class FaceMlService {
}
}
Future<double> _getIndexedDoneRatio() async {
final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start();
final int alreadyIndexedCount = await FaceMLDataDB.instance
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final int totalIndexableCount = (await getIndexableFileIDs()).length;
final ratio = alreadyIndexedCount / totalIndexableCount;
w?.log('getIndexedDoneRatio');
return ratio;
}
static Future<List<int>> getIndexableFileIDs() async {
return FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
}
bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
if (_isIndexingOrClusteringRunning == false ||
_mlControllerStatus == false) {

View File

@@ -1,6 +1,7 @@
import "dart:async";
import "dart:convert";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/network/network.dart";
import "package:photos/db/files_db.dart";
@@ -16,6 +17,8 @@ import "package:shared_preferences/shared_preferences.dart";
class RemoteFileMLService {
RemoteFileMLService._privateConstructor();
static final Computer _computer = Computer.shared();
static final RemoteFileMLService instance =
RemoteFileMLService._privateConstructor();
@@ -52,13 +55,13 @@ class RemoteFileMLService {
}
Future<FilesMLDataResponse> getFilessEmbedding(
List<int> fileIds,
Set<int> fileIds,
) async {
try {
final res = await _dio.post(
"/embeddings/files",
data: {
"fileIDs": fileIds,
"fileIDs": fileIds.toList(),
"model": 'file-ml-clip-face',
},
);
@@ -107,15 +110,17 @@ class RemoteFileMLService {
final input = EmbeddingsDecoderInput(embedding, fileKey);
inputs.add(input);
}
// todo: use compute or isolate
return decryptFileMLComputer(
{
return _computer.compute<Map<String, dynamic>, Map<int, FileMl>>(
_decryptFileMLComputer,
param: {
"inputs": inputs,
},
);
}
Future<Map<int, FileMl>> decryptFileMLComputer(
}
Future<Map<int, FileMl>> _decryptFileMLComputer(
Map<String, dynamic> args,
) async {
final result = <int, FileMl>{};
@@ -134,5 +139,4 @@ class RemoteFileMLService {
result[input.embedding.fileID] = decodedEmbedding;
}
return result;
}
}
}

View File

@@ -23,6 +23,7 @@ import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
import "package:photos/utils/thumbnail_util.dart";
class SemanticSearchService {
@@ -160,8 +161,7 @@ class SemanticSearchService {
}
Future<IndexStatus> getIndexStatus() async {
final indexableFileIDs = await FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
final indexableFileIDs = await getIndexableFileIDs();
return IndexStatus(
min(_cachedEmbeddings.length, indexableFileIDs.length),
(await _getFileIDsToBeIndexed()).length,
@@ -222,8 +222,7 @@ class SemanticSearchService {
}
Future<List<int>> _getFileIDsToBeIndexed() async {
final uploadedFileIDs = await FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
final uploadedFileIDs = await getIndexableFileIDs();
final embeddedFileIDs =
await EmbeddingsDB.instance.getFileIDs(_currentModel);

View File

@@ -754,15 +754,6 @@ class SearchService {
Future<List<GenericSearchResult>> getAllFace(int? limit) async {
try {
// Don't return anything if clustering is not nearly complete yet
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
final clusteredFaces =
await FaceMLDataDB.instance.getClusteredFaceCount();
final clusteringDoneRatio = clusteredFaces / foundFaces;
if (clusteringDoneRatio < 0.9) {
return [];
}
debugPrint("getting faces");
final Map<int, Set<int>> fileIdToClusterID =
await FaceMLDataDB.instance.getFileIdToClusterIds();

View File

@@ -177,7 +177,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: FutureBuilder<double>(
future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(),
future: FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(),
builder: (context, snapshot) {
if (snapshot.hasData) {
return CaptionedTextWidget(

View File

@@ -26,6 +26,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
import "package:photos/ui/components/toggle_switch_widget.dart";
import "package:photos/utils/data_util.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
final _logger = Logger("MachineLearningSettingsPage");
@@ -438,19 +439,16 @@ class FaceRecognitionStatusWidgetState
});
}
Future<(int, int, int, double)> getIndexStatus() async {
Future<(int, int, double)> getIndexStatus() async {
try {
final indexedFiles = await FaceMLDataDB.instance
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length;
final indexableFiles = (await getIndexableFileIDs()).length;
final showIndexedFiles = min(indexedFiles, indexableFiles);
final pendingFiles = max(indexableFiles - indexedFiles, 0);
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
final clusteredFaces =
await FaceMLDataDB.instance.getClusteredFaceCount();
final clusteringDoneRatio = clusteredFaces / foundFaces;
final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
return (showIndexedFiles, pendingFiles, clusteringDoneRatio);
} catch (e, s) {
_logger.severe('Error getting face recognition status', e, s);
rethrow;
@@ -479,8 +477,7 @@ class FaceRecognitionStatusWidgetState
if (snapshot.hasData) {
final int indexedFiles = snapshot.data!.$1;
final int pendingFiles = snapshot.data!.$2;
final int foundFaces = snapshot.data!.$3;
final double clusteringDoneRatio = snapshot.data!.$4;
final double clusteringDoneRatio = snapshot.data!.$3;
final double clusteringPercentage =
(clusteringDoneRatio * 100).clamp(0, 100);
@@ -512,19 +509,6 @@ class FaceRecognitionStatusWidgetState
isGestureDetectorDisabled: true,
key: ValueKey("pending_items_" + pendingFiles.toString()),
),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: S.of(context).foundFaces,
),
trailingWidget: Text(
NumberFormat().format(foundFaces),
style: Theme.of(context).textTheme.bodySmall,
),
singleBorderRadius: 8,
alignCaptionedTextToLeft: true,
isGestureDetectorDisabled: true,
key: ValueKey("found_faces_" + foundFaces.toString()),
),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: S.of(context).clusteringProgress,

View File

@@ -0,0 +1,7 @@
import "package:photos/core/configuration.dart";
import "package:photos/db/files_db.dart";
Future<List<int>> getIndexableFileIDs() async {
return FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
}

View File

@@ -12,7 +12,7 @@ description: ente photos application
# Read more about iOS versioning at
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
version: 0.8.109+633
version: 0.8.110+634
publish_to: none
environment:

View File

@@ -7,6 +7,7 @@
"@ente/accounts": "*",
"@ente/eslint-config": "*",
"@ente/shared": "*",
"jssha": "~3.3.1",
"otpauth": "^9"
}
}

View File

@@ -46,14 +46,11 @@ const AuthenticatorCodesPage = () => {
appContext.showNavBar(false);
}, []);
const lcSearch = searchTerm.toLowerCase();
const filteredCodes = codes.filter(
(secret) =>
(secret.issuer ?? "")
.toLowerCase()
.includes(searchTerm.toLowerCase()) ||
(secret.account ?? "")
.toLowerCase()
.includes(searchTerm.toLowerCase()),
(code) =>
code.issuer?.toLowerCase().includes(lcSearch) ||
code.account?.toLowerCase().includes(lcSearch),
);
if (!hasFetched) {
@@ -190,28 +187,21 @@ const CodeDisplay: React.FC<CodeDisplay> = ({ code }) => {
useEffect(() => {
// Generate to set the initial otp and nextOTP on component mount.
regen();
const codeType = code.type;
const codePeriodInMs = code.period * 1000;
const timeToNextCode =
codePeriodInMs - (new Date().getTime() % codePeriodInMs);
const interval = null;
const periodMs = code.period * 1000;
const timeToNextCode = periodMs - (Date.now() % periodMs);
let interval: ReturnType<typeof setInterval> | undefined;
// Wait until we are at the start of the next code period, and then
// start the interval loop.
setTimeout(() => {
// We need to call regen() once before the interval loop to set the
// initial otp and nextOTP.
regen();
codeType.toLowerCase() === "totp" ||
codeType.toLowerCase() === "hotp"
? setInterval(() => {
regen();
}, codePeriodInMs)
: null;
interval = setInterval(regen, periodMs);
}, timeToNextCode);
return () => {
if (interval) clearInterval(interval);
};
return () => interval && clearInterval(interval);
}, [code]);
return (
@@ -270,7 +260,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
textAlign: "left",
}}
>
{code.issuer}
{code.issuer ?? ""}
</p>
<p
style={{
@@ -283,7 +273,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
color: "grey",
}}
>
{code.account}
{code.account ?? ""}
</p>
<p
style={{
@@ -349,7 +339,7 @@ const TimerProgress: React.FC<TimerProgressProps> = ({ period }) => {
useEffect(() => {
const advance = () => {
const timeRemaining = us - ((new Date().getTime() * 1000) % us);
const timeRemaining = us - ((Date.now() * 1000) % us);
setProgress(timeRemaining / us);
};

View File

@@ -1,5 +1,6 @@
import { ensure } from "@/utils/ensure";
import { HOTP, TOTP } from "otpauth";
import { URI } from "vscode-uri";
import { Steam } from "./steam";
/**
* A parsed representation of an *OTP code URI.
@@ -10,13 +11,19 @@ export interface Code {
/** A unique id for the corresponding "auth entity" in our system. */
id?: String;
/** The type of the code. */
type: "totp" | "hotp";
type: "totp" | "hotp" | "steam";
/** The user's account or email for which this code is used. */
account: string;
account?: string;
/** The name of the entity that issued this code. */
issuer: string;
/** Number of digits in the generated OTP. */
digits: number;
/**
* Length of the generated OTP.
*
* This is vernacularly called "digits", which is an accurate description
* for the OG TOTP/HOTP codes. However, steam codes are not just digits, so
* we name this as a content-neutral "length".
*/
length: number;
/**
* The time period (in seconds) for which a single OTP generated from this
* code remains valid.
@@ -32,7 +39,7 @@ export interface Code {
/** The (HMAC) algorithm used by the OTP generator. */
algorithm: "sha1" | "sha256" | "sha512";
/** The original string from which this code was generated. */
uriString?: string;
uriString: string;
}
/**
@@ -45,100 +52,109 @@ export interface Code {
*
* - (TOTP)
* otpauth://totp/ACME:user@example.org?algorithm=SHA1&digits=6&issuer=acme&period=30&secret=ALPHANUM
*
* See also `auth/test/models/code_test.dart`.
*/
export const codeFromURIString = (id: string, uriString: string): Code => {
const santizedRawData = uriString
.replaceAll("+", "%2B")
.replaceAll(":", "%3A")
.replaceAll("\r", "")
// trim quotes
.replace(/^"|"$/g, "");
const uriParams = {};
const searchParamsString =
decodeURIComponent(santizedRawData).split("?")[1];
searchParamsString.split("&").forEach((pair) => {
const [key, value] = pair.split("=");
uriParams[key] = value;
});
const uri = URI.parse(santizedRawData);
let uriPath = decodeURIComponent(uri.path);
if (uriPath.startsWith("/otpauth://") || uriPath.startsWith("otpauth://")) {
uriPath = uriPath.split("otpauth://")[1];
} else if (uriPath.startsWith("otpauth%3A//")) {
uriPath = uriPath.split("otpauth%3A//")[1];
try {
return _codeFromURIString(id, uriString);
} catch (e) {
// We might have legacy encodings of account names that contain a "#",
// which causes the rest of the URL to be treated as a fragment, and
// ignored. See if this was potentially such a case, otherwise rethrow.
if (uriString.includes("#"))
return _codeFromURIString(id, uriString.replaceAll("#", "%23"));
throw e;
}
};
const _codeFromURIString = (id: string, uriString: string): Code => {
const url = new URL(uriString);
// A URL like
//
// new URL("otpauth://hotp/Test?secret=AAABBBCCCDDDEEEFFF&issuer=Test&counter=0")
//
// is parsed differently by the browser and Node depending on the scheme.
// When the scheme is http(s), then both of them consider "hotp" as the
// `host`. However, when the scheme is "otpauth", as is our case, the
// browser considers the entire thing as part of the pathname. so we get.
//
// host: ""
// pathname: "//hotp/Test"
//
// Since this code run on browsers only, we parse as per that behaviour.
const [type, path] = parsePathname(url);
return {
id,
type: _getType(uriPath),
account: _getAccount(uriPath),
issuer: _getIssuer(uriPath, uriParams),
digits: parseDigits(uriParams),
period: parsePeriod(uriParams),
secret: parseSecret(uriParams),
algorithm: parseAlgorithm(uriParams),
type,
account: parseAccount(path),
issuer: parseIssuer(url, path),
length: parseLength(url, type),
period: parsePeriod(url),
secret: parseSecret(url),
algorithm: parseAlgorithm(url),
uriString,
};
};
const _getType = (uriPath: string): Code["type"] => {
const oauthType = uriPath.split("/")[0].substring(0);
if (oauthType.toLowerCase() === "totp") {
return "totp";
} else if (oauthType.toLowerCase() === "hotp") {
return "hotp";
}
throw new Error(`Unsupported format with host ${oauthType}`);
const parsePathname = (url: URL): [type: Code["type"], path: string] => {
const p = url.pathname.toLowerCase();
if (p.startsWith("//totp")) return ["totp", url.pathname.slice(6)];
if (p.startsWith("//hotp")) return ["hotp", url.pathname.slice(6)];
if (p.startsWith("//steam")) return ["steam", url.pathname.slice(7)];
throw new Error(`Unsupported code or unparseable path "${url.pathname}"`);
};
const _getAccount = (uriPath: string): string => {
try {
const path = decodeURIComponent(uriPath);
if (path.includes(":")) {
return path.split(":")[1];
} else if (path.includes("/")) {
return path.split("/")[1];
}
} catch (e) {
return "";
}
const parseAccount = (path: string): string | undefined => {
// "/ACME:user@example.org" => "user@example.org"
let p = decodeURIComponent(path);
if (p.startsWith("/")) p = p.slice(1);
if (p.includes(":")) p = p.split(":").slice(1).join(":");
return p;
};
const _getIssuer = (uriPath: string, uriParams: { get?: any }): string => {
try {
if (uriParams["issuer"] !== undefined) {
let issuer = uriParams["issuer"];
// This is to handle bug in the ente auth app
if (issuer.endsWith("period")) {
issuer = issuer.substring(0, issuer.length - 6);
}
return issuer;
const parseIssuer = (url: URL, path: string): string => {
// If there is a "issuer" search param, use that.
let issuer = url.searchParams.get("issuer");
if (issuer) {
// This is to handle bug in old versions of Ente Auth app.
if (issuer.endsWith("period")) {
issuer = issuer.substring(0, issuer.length - 6);
}
let path = decodeURIComponent(uriPath);
if (path.startsWith("totp/") || path.startsWith("hotp/")) {
path = path.substring(5);
}
if (path.includes(":")) {
return path.split(":")[0];
} else if (path.includes("-")) {
return path.split("-")[0];
}
return path;
} catch (e) {
return "";
return issuer;
}
// Otherwise use the `prefix:` from the account as the issuer.
// "/ACME:user@example.org" => "ACME"
let p = decodeURIComponent(path);
if (p.startsWith("/")) p = p.slice(1);
if (p.includes(":")) p = p.split(":")[0];
else if (p.includes("-")) p = p.split("-")[0];
return p;
};
const parseDigits = (uriParams): number =>
parseInt(uriParams["digits"] ?? "", 10) || 6;
/**
* Parse the length of the generated code.
*
* The URI query param is called digits since originally TOTP/HOTP codes used
* this for generating numeric codes. Now we also support steam, which instead
* shows non-numeric codes, and also with a different default length of 5.
*/
const parseLength = (url: URL, type: Code["type"]): number => {
const defaultLength = type == "steam" ? 5 : 6;
return parseInt(url.searchParams.get("digits") ?? "", 10) || defaultLength;
};
const parsePeriod = (uriParams): number =>
parseInt(uriParams["period"] ?? "", 10) || 30;
const parsePeriod = (url: URL): number =>
parseInt(url.searchParams.get("period") ?? "", 10) || 30;
const parseAlgorithm = (uriParams): Code["algorithm"] => {
switch (uriParams["algorithm"]?.toLowerCase()) {
const parseAlgorithm = (url: URL): Code["algorithm"] => {
switch (url.searchParams.get("algorithm")?.toLowerCase()) {
case "sha256":
return "sha256";
case "sha512":
@@ -148,8 +164,8 @@ const parseAlgorithm = (uriParams): Code["algorithm"] => {
}
};
const parseSecret = (uriParams): string =>
uriParams["secret"].replaceAll(" ", "").toUpperCase();
const parseSecret = (url: URL): string =>
ensure(url.searchParams.get("secret")).replaceAll(" ", "").toUpperCase();
/**
* Generate a pair of OTPs (one time passwords) from the given {@link code}.
@@ -168,11 +184,11 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
secret: code.secret,
algorithm: code.algorithm,
period: code.period,
digits: code.digits,
digits: code.length,
});
otp = totp.generate();
nextOTP = totp.generate({
timestamp: new Date().getTime() + code.period * 1000,
timestamp: Date.now() + code.period * 1000,
});
break;
}
@@ -187,6 +203,17 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
nextOTP = hotp.generate({ counter: 1 });
break;
}
case "steam": {
const steam = new Steam({
secret: code.secret,
});
otp = steam.generate();
nextOTP = steam.generate({
timestamp: Date.now() + code.period * 1000,
});
break;
}
}
return [otp, nextOTP];
};

View File

@@ -35,7 +35,7 @@ export const getAuthCodes = async (): Promise<Code[]> => {
);
return codeFromURIString(entity.id, decryptedCode);
} catch (e) {
log.error(`failed to parse codeId = ${entity.id}`);
log.error(`Failed to parse codeID ${entity.id}`, e);
return null;
}
}),

View File

@@ -0,0 +1,74 @@
import jsSHA from "jssha";
import { Secret } from "otpauth";
/**
* Steam OTPs.
*
* Steam's algorithm is a custom variant of TOTP that uses a 26-character
* alphabet instead of digits.
*
* A Dart implementation of the algorithm can be found in
* https://github.com/elliotwutingfeng/steam_totp/blob/main/lib/src/steam_totp_base.dart
* (MIT license), and we use that as a reference. Our implementation is written
* in the style of the other TOTP/HOTP classes that are provided by the otpauth
* JS library that we use for the normal TOTP/HOTP generation
* https://github.com/hectorm/otpauth/blob/master/src/hotp.js (MIT license).
*/
export class Steam {
secret: Secret;
period: number;
constructor({ secret }: { secret: string }) {
this.secret = Secret.fromBase32(secret);
this.period = 30;
}
generate({ timestamp }: { timestamp: number } = { timestamp: Date.now() }) {
// Same as regular TOTP.
const counter = Math.floor(timestamp / 1000 / this.period);
// Same as regular HOTP, but algorithm is fixed to SHA-1.
const digest = sha1HMACDigest(this.secret.buffer, uintToArray(counter));
// Same calculation as regular HOTP.
const offset = digest[digest.length - 1] & 15;
let otp =
((digest[offset] & 127) << 24) |
((digest[offset + 1] & 255) << 16) |
((digest[offset + 2] & 255) << 8) |
(digest[offset + 3] & 255);
// However, instead of using this as the OTP, use it to index into
// the steam OTP alphabet.
const alphabet = "23456789BCDFGHJKMNPQRTVWXY";
const N = alphabet.length;
const steamOTP = [];
for (let i = 0; i < 5; i++) {
steamOTP.push(alphabet[otp % N]);
otp = Math.trunc(otp / N);
}
return steamOTP.join("");
}
}
// Equivalent to
// https://github.com/hectorm/otpauth/blob/master/src/utils/encoding/uint.js
const uintToArray = (n: number): Uint8Array => {
const result = new Uint8Array(8);
for (let i = 7; i >= 0; i--) {
result[i] = n & 255;
n >>= 8;
}
return result;
};
// We don't necessarily need a dependency on `jssha`, we could use SubtleCrypto
// here too. However, SubtleCrypto has an async interface, and we already have a
// transitive dependency on `jssha` via `otpauth`, so just using it here doesn't
// increase our bundle size any further.
const sha1HMACDigest = (key: ArrayBuffer, message: Uint8Array) => {
const hmac = new jsSHA("SHA-1", "UINT8ARRAY");
hmac.setHMACKey(key, "ARRAYBUFFER");
hmac.update(message);
return hmac.getHMAC("UINT8ARRAY");
};

View File

@@ -43,7 +43,6 @@
"similarity-transformation": "^0.0.1",
"transformation-matrix": "^2.16",
"uuid": "^9.0.1",
"vscode-uri": "^3.0.7",
"xml-js": "^1.6.11",
"zxcvbn": "^4.4.2"
},

View File

@@ -198,3 +198,7 @@ some cases.
- [otpauth](https://github.com/hectorm/otpauth) is used for the generation of
the actual OTP from the user's TOTP/HOTP secret.
- However, otpauth doesn't support steam OTPs. For these, we need to compute
the SHA-1, and we use the same library, `jssha` that `otpauth` uses (since
it is already part of our bundle).

View File

@@ -4804,11 +4804,6 @@ void-elements@3.1.0:
resolved "https://registry.yarnpkg.com/void-elements/-/void-elements-3.1.0.tgz#614f7fbf8d801f0bb5f0661f5b2f5785750e4f09"
integrity sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==
vscode-uri@^3.0.7:
version "3.0.8"
resolved "https://registry.yarnpkg.com/vscode-uri/-/vscode-uri-3.0.8.tgz#1770938d3e72588659a172d0fd4642780083ff9f"
integrity sha512-AyFQ0EVmsOZOlAnxoFOGOq1SQDWAB7C6aqMGS23svWAllfOaxbuFvcT8D1i8z3Gyn8fraVeZNNmN6e9bxxXkKw==
webidl-conversions@^3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"