Support for link password validation

This commit is contained in:
Neeraj Gupta
2025-07-17 15:27:21 +05:30
parent 8d108dc719
commit 51c00eefd4
7 changed files with 104 additions and 67 deletions

View File

@@ -51,8 +51,8 @@ type FileUrl struct {
CreatedAt int64 `json:"createdAt"`
}
type PublicFileAccessContext struct {
ID string
type FileLinkAccessContext struct {
LinkID string
IP string
UserAgent string
FileID int64

View File

@@ -8,7 +8,6 @@ import (
"github.com/ente-io/museum/pkg/repo/public"
"github.com/ente-io/museum/ente"
enteJWT "github.com/ente-io/museum/ente/jwt"
emailCtrl "github.com/ente-io/museum/pkg/controller/email"
"github.com/ente-io/museum/pkg/repo"
"github.com/ente-io/museum/pkg/utils/auth"
@@ -16,7 +15,6 @@ import (
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/lithammer/shortuuid/v3"
"github.com/sirupsen/logrus"
)
@@ -94,7 +92,7 @@ func (c *CollectionLinkController) CreateLink(ctx context.Context, req ente.Crea
}
func (c *CollectionLinkController) GetActiveCollectionLinkToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID)
return c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, collectionID)
}
func (c *CollectionLinkController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) {
@@ -126,7 +124,7 @@ func (c *CollectionLinkController) Disable(ctx context.Context, cID int64) error
}
func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) {
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, req.CollectionID)
publicCollectionToken, err := c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, req.CollectionID)
if err != nil {
return ente.PublicURL{}, err
}
@@ -180,50 +178,15 @@ func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente
// attack for guessing password.
func (c *CollectionLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
accessContext := auth.MustGetPublicAccessContext(ctx)
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, accessContext.CollectionID)
collectionLinkRow, err := c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, accessContext.CollectionID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get public collection info")
}
if publicCollectionToken.PassHash == nil || *publicCollectionToken.PassHash == "" {
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
}
if req.PassHash != *publicCollectionToken.PassHash {
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{
PassHash: req.PassHash,
ExpiryTime: time.NDaysFromNow(365),
})
// Sign and get the complete encoded token as a string using the secret
tokenString, err := token.SignedString(c.JwtSecret)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &ente.VerifyPasswordResponse{
JWTToken: tokenString,
}, nil
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
}
func (c *CollectionLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
}
return c.JwtSecret, nil
})
if err != nil {
return stacktrace.Propagate(err, "JWT parsed failed")
}
claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim)
if !ok {
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
}
if token.Valid && claims.PassHash == passwordHash {
return nil
}
return ente.ErrInvalidPassword
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
}
// ReportAbuse captures abuse report for a publicly shared collection.

View File

@@ -34,6 +34,23 @@ func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl
return nil, stacktrace.Propagate(err, "failed to create public file link")
}
// VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be
// used by the client to pass in other requests for public collection.
// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force
// attack for guessing password.
func (c *FileLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
accessContext := auth.MustGetFileLinkAccessContext(ctx)
collectionLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, accessContext.FileID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get public collection info")
}
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
}
func (c *FileLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
}
func (c *FileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl {
app := auth.GetApp(ctx)
var url string

View File

@@ -1 +1,54 @@
package public
import (
"errors"
"fmt"
"github.com/ente-io/museum/ente"
enteJWT "github.com/ente-io/museum/ente/jwt"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/golang-jwt/jwt"
)
func validateJWTToken(secret []byte, jwtToken string, passwordHash string) error {
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
}
return secret, nil
})
if err != nil {
return stacktrace.Propagate(err, "JWT parsed failed")
}
claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim)
if !ok {
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
}
if token.Valid && claims.PassHash == passwordHash {
return nil
}
return ente.ErrInvalidPassword
}
func verifyPassword(secret []byte, expectedPassHash *string, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
if expectedPassHash == nil || *expectedPassHash == "" {
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
}
if req.PassHash != *expectedPassHash {
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{
PassHash: req.PassHash,
ExpiryTime: time.NDaysFromNow(365),
})
// Sign and get the complete encoded token as a string using the secret
tokenString, err := token.SignedString(secret)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &ente.VerifyPasswordResponse{
JWTToken: tokenString,
}, nil
}

View File

@@ -25,12 +25,12 @@ var filePasswordWhiteListedURLs = []string{"/public-collection/info", "/public-c
// FileLinkMiddleware intercepts and authenticates incoming requests
type FileLinkMiddleware struct {
FileLinkRepo *public.FileLinkRepository
PublicCollectionCtrl *publicCtrl.CollectionLinkController
CollectionRepo *repo.CollectionRepository
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
FileLinkRepo *public.FileLinkRepository
FileLinkCtrl *publicCtrl.FileLinkController
CollectionRepo *repo.CollectionRepository
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
}
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
@@ -48,27 +48,27 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
cachedValue, cacheHit := m.Cache.Get(cacheKey)
var publicCollectionSummary *ente.FileLinkRow
var fileLinkRow *ente.FileLinkRow
var err error
if !cacheHit {
publicCollectionSummary, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
fileLinkRow, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if publicCollectionSummary.IsDisabled {
if fileLinkRow.IsDisabled {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
return
}
// validate if user still has active paid subscription
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(publicCollectionSummary.OwnerID, true); err != nil {
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(fileLinkRow.OwnerID, true); err != nil {
logrus.WithError(err).Warn("failed to verify active paid subscription")
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
return
}
// validate device limit
reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent)
reached, err := m.isDeviceLimitReached(c, fileLinkRow, clientIP, userAgent)
if err != nil {
logrus.WithError(err).Error("failed to check device limit")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
@@ -79,19 +79,19 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri
return
}
} else {
publicCollectionSummary = cachedValue.(*ente.FileLinkRow)
fileLinkRow = cachedValue.(*ente.FileLinkRow)
}
if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
publicCollectionSummary.ValidTill < time.Microseconds() {
if fileLinkRow.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
fileLinkRow.ValidTill < time.Microseconds() {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
return
}
// checks password protected public collection
if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" {
if fileLinkRow.PassHash != nil && *fileLinkRow.PassHash != "" {
reqPath := urlSanitizer(c)
if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil {
if err = m.validatePassword(c, reqPath, fileLinkRow); err != nil {
logrus.WithError(err).Warn("password validation failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
return
@@ -99,14 +99,14 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri
}
if !cacheHit {
m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration)
m.Cache.Set(cacheKey, fileLinkRow, cache.DefaultExpiration)
}
c.Set(auth.FileLinkAccessKey, &ente.PublicFileAccessContext{
ID: publicCollectionSummary.LinkID,
c.Set(auth.FileLinkAccessKey, &ente.FileLinkAccessContext{
LinkID: fileLinkRow.LinkID,
IP: clientIP,
UserAgent: userAgent,
FileID: publicCollectionSummary.FileID,
FileID: fileLinkRow.FileID,
})
c.Next()
}
@@ -168,5 +168,5 @@ func (m *FileLinkMiddleware) validatePassword(c *gin.Context, reqPath string,
if accessTokenJWT == "" {
return ente.ErrAuthenticationRequired
}
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
return m.FileLinkCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
}

View File

@@ -92,9 +92,9 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con
return result, nil
}
// GetActivePublicCollectionToken will return ente.CollectionLinkRow for given collection ID
// GetActiveCollectionLinkRow will return ente.CollectionLinkRow for given collection ID
// Note: The token could be expired or deviceLimit is already reached
func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
func (pcr *PublicCollectionRepository) GetActiveCollectionLinkRow(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit,
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM
public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`,

View File

@@ -133,6 +133,10 @@ func MustGetPublicAccessContext(c *gin.Context) ente.PublicAccessContext {
return c.MustGet(PublicAccessKey).(ente.PublicAccessContext)
}
func MustGetFileLinkAccessContext(c *gin.Context) *ente.FileLinkAccessContext {
return c.MustGet(FileLinkAccessKey).(*ente.FileLinkAccessContext)
}
func GetCastCtx(c *gin.Context) cast.AuthContext {
return c.MustGet(CastContext).(cast.AuthContext)
}