diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 0656f12e61..5f8534b9d3 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -445,6 +445,7 @@ func main() { publicAPI.POST("/users/two-factor/remove", userHandler.RemoveTwoFactor) publicAPI.POST("/users/two-factor/passkeys/begin", userHandler.BeginPasskeyAuthenticationCeremony) publicAPI.POST("/users/two-factor/passkeys/finish", userHandler.FinishPasskeyAuthenticationCeremony) + publicAPI.GET("/users/two-factor/passkeys/get-token", userHandler.GetTokenForPasskeySession) privateAPI.GET("/users/two-factor/recovery-status", userHandler.GetTwoFactorRecoveryStatus) privateAPI.POST("/users/two-factor/passkeys/configure-recovery", userHandler.ConfigurePasskeyRecovery) privateAPI.GET("/users/two-factor/status", userHandler.GetTwoFactorStatus) diff --git a/server/pkg/api/user.go b/server/pkg/api/user.go index 4624ab05ff..c38df8f7b7 100644 --- a/server/pkg/api/user.go +++ b/server/pkg/api/user.go @@ -383,6 +383,20 @@ func (h *UserHandler) FinishPasskeyAuthenticationCeremony(c *gin.Context) { c.JSON(http.StatusOK, response) } +func (h *UserHandler) GetTokenForPasskeySession(c *gin.Context) { + sessionID := c.Query("sessionID") + if sessionID == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("sessionID is required"), "")) + return + } + response, err := h.UserController.PasskeyRepo.GetTokenData(sessionID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get token data")) + return + } + c.JSON(http.StatusOK, response) +} + func (h *UserHandler) IsPasskeyRecoveryEnabled(c *gin.Context) { userID := auth.GetUserID(c.Request.Header) response, err := h.UserController.GetKeyAttributeAndToken(c, userID) diff --git a/server/pkg/repo/passkey/passkey.go b/server/pkg/repo/passkey/passkey.go index 52221ab74b..677d17763f 100644 --- a/server/pkg/repo/passkey/passkey.go +++ b/server/pkg/repo/passkey/passkey.go @@ -19,6 +19,13 @@ import ( "github.com/go-webauthn/webauthn/webauthn" ) +const ( + // MaxSessionTokenFetchLimit specifies the maximum number of requests a client can make to retrieve token data for a given session ID. + MaxSessionTokenFetchLimit = 2 + // TokenFetchAllowedDurationInMin is the duration in minutes for which the token fetch is allowed after the session is verified. + TokenFetchAllowedDurationInMin = 2 +) + type Repository struct { DB *sql.DB webAuthnInstance *webauthn.WebAuthn @@ -177,6 +184,64 @@ func (r *Repository) StoreTokenData(sessionID string, tokenData ente.TwoFactorAu return stacktrace.Propagate(err, "") } +// GetTokenData retrieves the token data associated with a given session ID. +// The function will return the token data if the following conditions are met: +// - The token data is not null. +// - The session was verified less than 5 minutes ago. +// - The token fetch count is less than 2. +// If these conditions are met, the function will also increment the token fetch count by 1. +// +// Parameters: +// - sessionID: The ID of the session for which to retrieve the token data. +// +// Returns: +// - A pointer to a TwoFactorAuthorizationResponse object containing the token data, if the conditions are met. +// - An error, if an error occurred while retrieving the token data or if the conditions are not met. +func (r *Repository) GetTokenData(sessionID string) (*ente.TwoFactorAuthorizationResponse, error) { + var tokenDataJson []byte + var verifiedAt sql.NullInt64 + var fetchCount int + err := r.DB.QueryRow(`SELECT token_data, verified_at, token_fetch_cnt FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&tokenDataJson, &verifiedAt, &fetchCount) + if err != nil { + if err == sql.ErrNoRows { + return nil, ente.ErrNotFound + } + return nil, stacktrace.Propagate(err, "") + } + if verifiedAt.Valid == false { + return nil, &ente.ApiError{ + Code: "SESSION_NOT_VERIFIED", + Message: "Session is not verified yet", + HttpStatusCode: http.StatusBadRequest, + } + } + if verifiedAt.Int64 < ente_time.MicrosecondsBeforeMinutes(TokenFetchAllowedDurationInMin) { + return nil, &ente.ApiError{ + Code: "INVALID_SESSION", + Message: "Session verified but expired now", + HttpStatusCode: http.StatusGone, + } + } + if fetchCount >= MaxSessionTokenFetchLimit { + return nil, &ente.ApiError{ + Code: "INVALID_SESSION", + Message: "Token fetch limit reached", + HttpStatusCode: http.StatusGone, + } + } + var tokenData ente.TwoFactorAuthorizationResponse + err = json.Unmarshal(tokenDataJson, &tokenData) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + // update the token_fetch_count + _, err = r.DB.Exec(`UPDATE passkey_login_sessions SET token_fetch_cnt = token_fetch_cnt + 1 WHERE session_id = $1`, sessionID) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &tokenData, nil +} + func (r *Repository) CreateBeginAuthenticationData(user *ente.User) (options *protocol.CredentialAssertion, session *webauthn.SessionData, id uuid.UUID, err error) { passkeyUser := &PasskeyUser{ User: user,