[server] Reduce passkey JWT duration + API to get token via sessionID (#2111)

## Description

## Tests
Monkey tested locally
This commit is contained in:
Neeraj Gupta
2024-06-13 18:07:37 +05:30
committed by GitHub
8 changed files with 148 additions and 1 deletions

View File

@@ -445,6 +445,7 @@ func main() {
publicAPI.POST("/users/two-factor/remove", userHandler.RemoveTwoFactor)
publicAPI.POST("/users/two-factor/passkeys/begin", userHandler.BeginPasskeyAuthenticationCeremony)
publicAPI.POST("/users/two-factor/passkeys/finish", userHandler.FinishPasskeyAuthenticationCeremony)
publicAPI.GET("/users/two-factor/passkeys/get-token", userHandler.GetTokenForPasskeySession)
privateAPI.GET("/users/two-factor/recovery-status", userHandler.GetTwoFactorRecoveryStatus)
privateAPI.POST("/users/two-factor/passkeys/configure-recovery", userHandler.ConfigurePasskeyRecovery)
privateAPI.GET("/users/two-factor/status", userHandler.GetTwoFactorStatus)

View File

@@ -125,6 +125,12 @@ var ErrFileNotFoundInAlbum = ApiError{
Message: "File is either deleted or moved to different collection",
}
var ErrSessionAlreadyClaimed = ApiError{
Code: "SESSION_ALREADY_CLAIMED",
Message: "Session is already claimed",
HttpStatusCode: http.StatusConflict,
}
var ErrPublicCollectDisabled = ApiError{
Code: PublicCollectDisabled,
Message: "User has not enabled public collect for this url",

View File

@@ -0,0 +1,6 @@
-- Add types for the new dcs that are introduced for the derived data
ALTER TABLE passkey_login_sessions
DROP COLUMN IF EXISTS token_fetch_cnt,
DROP COLUMN IF EXISTS verified_at,
DROP COLUMN IF EXISTS token_data;

View File

@@ -0,0 +1,6 @@
-- Add columns to passkey_login_sessions table for facilitating token fetch in case of passkey redirect
-- not working.
ALTER TABLE passkey_login_sessions
ADD COLUMN IF NOT EXISTS token_fetch_cnt int default 0,
ADD COLUMN IF NOT EXISTS verified_at BIGINT,
ADD COLUMN IF NOT EXISTS token_data jsonb;

View File

@@ -325,6 +325,17 @@ func (h *UserHandler) BeginPasskeyAuthenticationCeremony(c *gin.Context) {
return
}
isSessionAlreadyClaimed, err := h.UserController.PasskeyRepo.IsSessionAlreadyClaimed(request.SessionID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
if isSessionAlreadyClaimed {
handler.Error(c, stacktrace.Propagate(&ente.ErrSessionAlreadyClaimed, "Session already claimed"))
return
}
user, err := h.UserController.UserRepo.Get(userID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
@@ -374,6 +385,26 @@ func (h *UserHandler) FinishPasskeyAuthenticationCeremony(c *gin.Context) {
return
}
err = h.UserController.PasskeyRepo.StoreTokenData(request.SessionID, response)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, "failed to store token data"))
return
}
c.JSON(http.StatusOK, response)
}
func (h *UserHandler) GetTokenForPasskeySession(c *gin.Context) {
sessionID := c.Query("sessionID")
if sessionID == "" {
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("sessionID is required"), ""))
return
}
response, err := h.UserController.PasskeyRepo.GetTokenData(sessionID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, "failed to get token data"))
return
}
c.JSON(http.StatusOK, response)
}

View File

@@ -13,11 +13,15 @@ import (
const ValidForDays = 1
func (c *UserController) GetJWTToken(userID int64, scope enteJWT.ClaimScope) (string, error) {
tokenExpiry := time.NDaysFromNow(1)
if scope == enteJWT.ACCOUNTS {
tokenExpiry = time.NMinFromNow(30)
}
// Create a new token object, specifying signing method and the claims
// you would like it to contain.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.WebCommonJWTClaim{
UserID: userID,
ExpiryTime: time.NDaysFromNow(1),
ExpiryTime: tokenExpiry,
ClaimScope: &scope,
})
// Sign and get the complete encoded token as a string using the secret

View File

@@ -19,6 +19,13 @@ import (
"github.com/go-webauthn/webauthn/webauthn"
)
const (
// MaxSessionTokenFetchLimit specifies the maximum number of requests a client can make to retrieve token data for a given session ID.
MaxSessionTokenFetchLimit = 2
// TokenFetchAllowedDurationInMin is the duration in minutes for which the token fetch is allowed after the session is verified.
TokenFetchAllowedDurationInMin = 2
)
type Repository struct {
DB *sql.DB
webAuthnInstance *webauthn.WebAuthn
@@ -167,6 +174,87 @@ func (r *Repository) GetUserIDWithPasskeyTwoFactorSession(sessionID string) (use
return
}
// IsSessionAlreadyClaimed checks if the both token_data and verified_at are not null for a given session ID
func (r *Repository) IsSessionAlreadyClaimed(sessionID string) (bool, error) {
var verifiedAt sql.NullInt64
err := r.DB.QueryRow(`SELECT verified_at FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&verifiedAt)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, stacktrace.Propagate(err, "")
}
return verifiedAt.Valid, nil
}
// StoreTokenData takes a sessionID, and tokenData, and updates the tokenData in the database
func (r *Repository) StoreTokenData(sessionID string, tokenData ente.TwoFactorAuthorizationResponse) error {
tokenDataJson, err := json.Marshal(tokenData)
if err != nil {
return stacktrace.Propagate(err, "")
}
_, err = r.DB.Exec(`UPDATE passkey_login_sessions SET token_data = $1, verified_at = now_utc_micro_seconds() WHERE session_id = $2`, tokenDataJson, sessionID)
return stacktrace.Propagate(err, "")
}
// GetTokenData retrieves the token data associated with a given session ID.
// The function will return the token data if the following conditions are met:
// - The token data is not null.
// - The session was verified less than 5 minutes ago.
// - The token fetch count is less than 2.
// If these conditions are met, the function will also increment the token fetch count by 1.
//
// Parameters:
// - sessionID: The ID of the session for which to retrieve the token data.
//
// Returns:
// - A pointer to a TwoFactorAuthorizationResponse object containing the token data, if the conditions are met.
// - An error, if an error occurred while retrieving the token data or if the conditions are not met.
func (r *Repository) GetTokenData(sessionID string) (*ente.TwoFactorAuthorizationResponse, error) {
var tokenDataJson []byte
var verifiedAt sql.NullInt64
var fetchCount int
err := r.DB.QueryRow(`SELECT token_data, verified_at, token_fetch_cnt FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&tokenDataJson, &verifiedAt, &fetchCount)
if err != nil {
if err == sql.ErrNoRows {
return nil, ente.ErrNotFound
}
return nil, stacktrace.Propagate(err, "")
}
if !verifiedAt.Valid {
return nil, &ente.ApiError{
Code: "SESSION_NOT_VERIFIED",
Message: "Session is not verified yet",
HttpStatusCode: http.StatusBadRequest,
}
}
if verifiedAt.Int64 < ente_time.MicrosecondsBeforeMinutes(TokenFetchAllowedDurationInMin) {
return nil, &ente.ApiError{
Code: "INVALID_SESSION",
Message: "Session verified but expired now",
HttpStatusCode: http.StatusGone,
}
}
if fetchCount >= MaxSessionTokenFetchLimit {
return nil, &ente.ApiError{
Code: "INVALID_SESSION",
Message: "Token fetch limit reached",
HttpStatusCode: http.StatusGone,
}
}
var tokenData ente.TwoFactorAuthorizationResponse
err = json.Unmarshal(tokenDataJson, &tokenData)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
// update the token_fetch_count
_, err = r.DB.Exec(`UPDATE passkey_login_sessions SET token_fetch_cnt = token_fetch_cnt + 1 WHERE session_id = $1`, sessionID)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &tokenData, nil
}
func (r *Repository) CreateBeginAuthenticationData(user *ente.User) (options *protocol.CredentialAssertion, session *webauthn.SessionData, id uuid.UUID, err error) {
passkeyUser := &PasskeyUser{
User: user,

View File

@@ -48,6 +48,11 @@ func NDaysFromNow(n int) int64 {
return time.Now().AddDate(0, 0, n).UnixNano() / 1000
}
// NMinFromNow returns the time n min from now in micro seconds
func NMinFromNow(n int64) int64 {
return time.Now().Add(time.Minute*time.Duration(n)).UnixNano() / 1000
}
// MicrosecondsBeforeMinutes returns the unix time n minutes before now in micro seconds
func MicrosecondsBeforeMinutes(noOfMinutes int64) int64 {
return Microseconds() - (MicroSecondsInOneMinute * noOfMinutes)