diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 0656f12e61..5f8534b9d3 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -445,6 +445,7 @@ func main() { publicAPI.POST("/users/two-factor/remove", userHandler.RemoveTwoFactor) publicAPI.POST("/users/two-factor/passkeys/begin", userHandler.BeginPasskeyAuthenticationCeremony) publicAPI.POST("/users/two-factor/passkeys/finish", userHandler.FinishPasskeyAuthenticationCeremony) + publicAPI.GET("/users/two-factor/passkeys/get-token", userHandler.GetTokenForPasskeySession) privateAPI.GET("/users/two-factor/recovery-status", userHandler.GetTwoFactorRecoveryStatus) privateAPI.POST("/users/two-factor/passkeys/configure-recovery", userHandler.ConfigurePasskeyRecovery) privateAPI.GET("/users/two-factor/status", userHandler.GetTwoFactorStatus) diff --git a/server/ente/errors.go b/server/ente/errors.go index 96e7bd4a1e..89fdebb17f 100644 --- a/server/ente/errors.go +++ b/server/ente/errors.go @@ -125,6 +125,12 @@ var ErrFileNotFoundInAlbum = ApiError{ Message: "File is either deleted or moved to different collection", } +var ErrSessionAlreadyClaimed = ApiError{ + Code: "SESSION_ALREADY_CLAIMED", + Message: "Session is already claimed", + HttpStatusCode: http.StatusConflict, +} + var ErrPublicCollectDisabled = ApiError{ Code: PublicCollectDisabled, Message: "User has not enabled public collect for this url", diff --git a/server/migrations/87_passkey_login_token.down.sql b/server/migrations/87_passkey_login_token.down.sql new file mode 100644 index 0000000000..1af9da976c --- /dev/null +++ b/server/migrations/87_passkey_login_token.down.sql @@ -0,0 +1,6 @@ +-- Add types for the new dcs that are introduced for the derived data + +ALTER TABLE passkey_login_sessions + DROP COLUMN IF EXISTS token_fetch_cnt, + DROP COLUMN IF EXISTS verified_at, + DROP COLUMN IF EXISTS token_data; diff --git a/server/migrations/87_passkey_login_token.up.sql b/server/migrations/87_passkey_login_token.up.sql new file mode 100644 index 0000000000..0c44050aca --- /dev/null +++ b/server/migrations/87_passkey_login_token.up.sql @@ -0,0 +1,6 @@ +-- Add columns to passkey_login_sessions table for facilitating token fetch in case of passkey redirect +-- not working. +ALTER TABLE passkey_login_sessions + ADD COLUMN IF NOT EXISTS token_fetch_cnt int default 0, + ADD COLUMN IF NOT EXISTS verified_at BIGINT, + ADD COLUMN IF NOT EXISTS token_data jsonb; diff --git a/server/pkg/api/user.go b/server/pkg/api/user.go index eca3804e5c..71e050fdeb 100644 --- a/server/pkg/api/user.go +++ b/server/pkg/api/user.go @@ -325,6 +325,17 @@ func (h *UserHandler) BeginPasskeyAuthenticationCeremony(c *gin.Context) { return } + isSessionAlreadyClaimed, err := h.UserController.PasskeyRepo.IsSessionAlreadyClaimed(request.SessionID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + + if isSessionAlreadyClaimed { + handler.Error(c, stacktrace.Propagate(&ente.ErrSessionAlreadyClaimed, "Session already claimed")) + return + } + user, err := h.UserController.UserRepo.Get(userID) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) @@ -374,6 +385,26 @@ func (h *UserHandler) FinishPasskeyAuthenticationCeremony(c *gin.Context) { return } + err = h.UserController.PasskeyRepo.StoreTokenData(request.SessionID, response) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to store token data")) + return + } + + c.JSON(http.StatusOK, response) +} + +func (h *UserHandler) GetTokenForPasskeySession(c *gin.Context) { + sessionID := c.Query("sessionID") + if sessionID == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("sessionID is required"), "")) + return + } + response, err := h.UserController.PasskeyRepo.GetTokenData(sessionID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get token data")) + return + } c.JSON(http.StatusOK, response) } diff --git a/server/pkg/controller/user/jwt.go b/server/pkg/controller/user/jwt.go index d920e36b0b..d804f4cef3 100644 --- a/server/pkg/controller/user/jwt.go +++ b/server/pkg/controller/user/jwt.go @@ -13,11 +13,15 @@ import ( const ValidForDays = 1 func (c *UserController) GetJWTToken(userID int64, scope enteJWT.ClaimScope) (string, error) { + tokenExpiry := time.NDaysFromNow(1) + if scope == enteJWT.ACCOUNTS { + tokenExpiry = time.NMinFromNow(30) + } // Create a new token object, specifying signing method and the claims // you would like it to contain. token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.WebCommonJWTClaim{ UserID: userID, - ExpiryTime: time.NDaysFromNow(1), + ExpiryTime: tokenExpiry, ClaimScope: &scope, }) // Sign and get the complete encoded token as a string using the secret diff --git a/server/pkg/repo/passkey/passkey.go b/server/pkg/repo/passkey/passkey.go index 5f8d3d642d..131f16b836 100644 --- a/server/pkg/repo/passkey/passkey.go +++ b/server/pkg/repo/passkey/passkey.go @@ -19,6 +19,13 @@ import ( "github.com/go-webauthn/webauthn/webauthn" ) +const ( + // MaxSessionTokenFetchLimit specifies the maximum number of requests a client can make to retrieve token data for a given session ID. + MaxSessionTokenFetchLimit = 2 + // TokenFetchAllowedDurationInMin is the duration in minutes for which the token fetch is allowed after the session is verified. + TokenFetchAllowedDurationInMin = 2 +) + type Repository struct { DB *sql.DB webAuthnInstance *webauthn.WebAuthn @@ -167,6 +174,87 @@ func (r *Repository) GetUserIDWithPasskeyTwoFactorSession(sessionID string) (use return } +// IsSessionAlreadyClaimed checks if the both token_data and verified_at are not null for a given session ID +func (r *Repository) IsSessionAlreadyClaimed(sessionID string) (bool, error) { + var verifiedAt sql.NullInt64 + err := r.DB.QueryRow(`SELECT verified_at FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&verifiedAt) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, stacktrace.Propagate(err, "") + } + return verifiedAt.Valid, nil +} + +// StoreTokenData takes a sessionID, and tokenData, and updates the tokenData in the database +func (r *Repository) StoreTokenData(sessionID string, tokenData ente.TwoFactorAuthorizationResponse) error { + tokenDataJson, err := json.Marshal(tokenData) + if err != nil { + return stacktrace.Propagate(err, "") + } + _, err = r.DB.Exec(`UPDATE passkey_login_sessions SET token_data = $1, verified_at = now_utc_micro_seconds() WHERE session_id = $2`, tokenDataJson, sessionID) + return stacktrace.Propagate(err, "") +} + +// GetTokenData retrieves the token data associated with a given session ID. +// The function will return the token data if the following conditions are met: +// - The token data is not null. +// - The session was verified less than 5 minutes ago. +// - The token fetch count is less than 2. +// If these conditions are met, the function will also increment the token fetch count by 1. +// +// Parameters: +// - sessionID: The ID of the session for which to retrieve the token data. +// +// Returns: +// - A pointer to a TwoFactorAuthorizationResponse object containing the token data, if the conditions are met. +// - An error, if an error occurred while retrieving the token data or if the conditions are not met. +func (r *Repository) GetTokenData(sessionID string) (*ente.TwoFactorAuthorizationResponse, error) { + var tokenDataJson []byte + var verifiedAt sql.NullInt64 + var fetchCount int + err := r.DB.QueryRow(`SELECT token_data, verified_at, token_fetch_cnt FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&tokenDataJson, &verifiedAt, &fetchCount) + if err != nil { + if err == sql.ErrNoRows { + return nil, ente.ErrNotFound + } + return nil, stacktrace.Propagate(err, "") + } + if !verifiedAt.Valid { + return nil, &ente.ApiError{ + Code: "SESSION_NOT_VERIFIED", + Message: "Session is not verified yet", + HttpStatusCode: http.StatusBadRequest, + } + } + if verifiedAt.Int64 < ente_time.MicrosecondsBeforeMinutes(TokenFetchAllowedDurationInMin) { + return nil, &ente.ApiError{ + Code: "INVALID_SESSION", + Message: "Session verified but expired now", + HttpStatusCode: http.StatusGone, + } + } + if fetchCount >= MaxSessionTokenFetchLimit { + return nil, &ente.ApiError{ + Code: "INVALID_SESSION", + Message: "Token fetch limit reached", + HttpStatusCode: http.StatusGone, + } + } + var tokenData ente.TwoFactorAuthorizationResponse + err = json.Unmarshal(tokenDataJson, &tokenData) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + // update the token_fetch_count + _, err = r.DB.Exec(`UPDATE passkey_login_sessions SET token_fetch_cnt = token_fetch_cnt + 1 WHERE session_id = $1`, sessionID) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &tokenData, nil +} + func (r *Repository) CreateBeginAuthenticationData(user *ente.User) (options *protocol.CredentialAssertion, session *webauthn.SessionData, id uuid.UUID, err error) { passkeyUser := &PasskeyUser{ User: user, diff --git a/server/pkg/utils/time/time.go b/server/pkg/utils/time/time.go index c03f97696d..a07df4b262 100644 --- a/server/pkg/utils/time/time.go +++ b/server/pkg/utils/time/time.go @@ -48,6 +48,11 @@ func NDaysFromNow(n int) int64 { return time.Now().AddDate(0, 0, n).UnixNano() / 1000 } +// NMinFromNow returns the time n min from now in micro seconds +func NMinFromNow(n int64) int64 { + return time.Now().Add(time.Minute*time.Duration(n)).UnixNano() / 1000 +} + // MicrosecondsBeforeMinutes returns the unix time n minutes before now in micro seconds func MicrosecondsBeforeMinutes(noOfMinutes int64) int64 { return Microseconds() - (MicroSecondsInOneMinute * noOfMinutes)