diff --git a/server/ente/public_file.go b/server/ente/public_file.go index 3b43ac0f43..12398c38a1 100644 --- a/server/ente/public_file.go +++ b/server/ente/public_file.go @@ -51,8 +51,8 @@ type FileUrl struct { CreatedAt int64 `json:"createdAt"` } -type PublicFileAccessContext struct { - ID string +type FileLinkAccessContext struct { + LinkID string IP string UserAgent string FileID int64 diff --git a/server/pkg/controller/public/collection_link.go b/server/pkg/controller/public/collection_link.go index 1c47b04cd9..bec8393425 100644 --- a/server/pkg/controller/public/collection_link.go +++ b/server/pkg/controller/public/collection_link.go @@ -8,7 +8,6 @@ import ( "github.com/ente-io/museum/pkg/repo/public" "github.com/ente-io/museum/ente" - enteJWT "github.com/ente-io/museum/ente/jwt" emailCtrl "github.com/ente-io/museum/pkg/controller/email" "github.com/ente-io/museum/pkg/repo" "github.com/ente-io/museum/pkg/utils/auth" @@ -16,7 +15,6 @@ import ( "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" "github.com/lithammer/shortuuid/v3" "github.com/sirupsen/logrus" ) @@ -94,7 +92,7 @@ func (c *CollectionLinkController) CreateLink(ctx context.Context, req ente.Crea } func (c *CollectionLinkController) GetActiveCollectionLinkToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) { - return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID) + return c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, collectionID) } func (c *CollectionLinkController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) { @@ -126,7 +124,7 @@ func (c *CollectionLinkController) Disable(ctx context.Context, cID int64) error } func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) { - publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, req.CollectionID) + publicCollectionToken, err := c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, req.CollectionID) if err != nil { return ente.PublicURL{}, err } @@ -180,50 +178,15 @@ func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente // attack for guessing password. func (c *CollectionLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { accessContext := auth.MustGetPublicAccessContext(ctx) - publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, accessContext.CollectionID) + collectionLinkRow, err := c.PublicCollectionRepo.GetActiveCollectionLinkRow(ctx, accessContext.CollectionID) if err != nil { return nil, stacktrace.Propagate(err, "failed to get public collection info") } - if publicCollectionToken.PassHash == nil || *publicCollectionToken.PassHash == "" { - return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link") - } - if req.PassHash != *publicCollectionToken.PassHash { - return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link") - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{ - PassHash: req.PassHash, - ExpiryTime: time.NDaysFromNow(365), - }) - // Sign and get the complete encoded token as a string using the secret - tokenString, err := token.SignedString(c.JwtSecret) - - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return &ente.VerifyPasswordResponse{ - JWTToken: tokenString, - }, nil + return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req) } func (c *CollectionLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error { - token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil - } - return c.JwtSecret, nil - }) - if err != nil { - return stacktrace.Propagate(err, "JWT parsed failed") - } - claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim) - - if !ok { - return stacktrace.Propagate(errors.New("no claim in jwt token"), "") - } - if token.Valid && claims.PassHash == passwordHash { - return nil - } - return ente.ErrInvalidPassword + return validateJWTToken(c.JwtSecret, jwtToken, passwordHash) } // ReportAbuse captures abuse report for a publicly shared collection. diff --git a/server/pkg/controller/public/file_link.go b/server/pkg/controller/public/file_link.go index bf1ad07865..edb90ec155 100644 --- a/server/pkg/controller/public/file_link.go +++ b/server/pkg/controller/public/file_link.go @@ -34,6 +34,23 @@ func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl return nil, stacktrace.Propagate(err, "failed to create public file link") } +// VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be +// used by the client to pass in other requests for public collection. +// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force +// attack for guessing password. +func (c *FileLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { + accessContext := auth.MustGetFileLinkAccessContext(ctx) + collectionLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, accessContext.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get public collection info") + } + return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req) +} + +func (c *FileLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error { + return validateJWTToken(c.JwtSecret, jwtToken, passwordHash) +} + func (c *FileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl { app := auth.GetApp(ctx) var url string diff --git a/server/pkg/controller/public/link_common.go b/server/pkg/controller/public/link_common.go index 9fd24e5c1e..9a56334a0e 100644 --- a/server/pkg/controller/public/link_common.go +++ b/server/pkg/controller/public/link_common.go @@ -1 +1,54 @@ package public + +import ( + "errors" + "fmt" + "github.com/ente-io/museum/ente" + enteJWT "github.com/ente-io/museum/ente/jwt" + "github.com/ente-io/museum/pkg/utils/time" + "github.com/ente-io/stacktrace" + "github.com/golang-jwt/jwt" +) + +func validateJWTToken(secret []byte, jwtToken string, passwordHash string) error { + token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil + } + return secret, nil + }) + if err != nil { + return stacktrace.Propagate(err, "JWT parsed failed") + } + claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim) + + if !ok { + return stacktrace.Propagate(errors.New("no claim in jwt token"), "") + } + if token.Valid && claims.PassHash == passwordHash { + return nil + } + return ente.ErrInvalidPassword +} + +func verifyPassword(secret []byte, expectedPassHash *string, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { + if expectedPassHash == nil || *expectedPassHash == "" { + return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link") + } + if req.PassHash != *expectedPassHash { + return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link") + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{ + PassHash: req.PassHash, + ExpiryTime: time.NDaysFromNow(365), + }) + // Sign and get the complete encoded token as a string using the secret + tokenString, err := token.SignedString(secret) + + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &ente.VerifyPasswordResponse{ + JWTToken: tokenString, + }, nil +} diff --git a/server/pkg/middleware/file_link_token.go b/server/pkg/middleware/file_link_token.go index 5c230882b5..539ca1179e 100644 --- a/server/pkg/middleware/file_link_token.go +++ b/server/pkg/middleware/file_link_token.go @@ -25,12 +25,12 @@ var filePasswordWhiteListedURLs = []string{"/public-collection/info", "/public-c // FileLinkMiddleware intercepts and authenticates incoming requests type FileLinkMiddleware struct { - FileLinkRepo *public.FileLinkRepository - PublicCollectionCtrl *publicCtrl.CollectionLinkController - CollectionRepo *repo.CollectionRepository - Cache *cache.Cache - BillingCtrl *controller.BillingController - DiscordController *discord.DiscordController + FileLinkRepo *public.FileLinkRepository + FileLinkCtrl *publicCtrl.FileLinkController + CollectionRepo *repo.CollectionRepository + Cache *cache.Cache + BillingCtrl *controller.BillingController + DiscordController *discord.DiscordController } // Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` @@ -48,27 +48,27 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":") cachedValue, cacheHit := m.Cache.Get(cacheKey) - var publicCollectionSummary *ente.FileLinkRow + var fileLinkRow *ente.FileLinkRow var err error if !cacheHit { - publicCollectionSummary, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken) + fileLinkRow, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken) if err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) return } - if publicCollectionSummary.IsDisabled { + if fileLinkRow.IsDisabled { c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"}) return } // validate if user still has active paid subscription - if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(publicCollectionSummary.OwnerID, true); err != nil { + if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(fileLinkRow.OwnerID, true); err != nil { logrus.WithError(err).Warn("failed to verify active paid subscription") c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"}) return } // validate device limit - reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent) + reached, err := m.isDeviceLimitReached(c, fileLinkRow, clientIP, userAgent) if err != nil { logrus.WithError(err).Error("failed to check device limit") c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"}) @@ -79,19 +79,19 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri return } } else { - publicCollectionSummary = cachedValue.(*ente.FileLinkRow) + fileLinkRow = cachedValue.(*ente.FileLinkRow) } - if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry - publicCollectionSummary.ValidTill < time.Microseconds() { + if fileLinkRow.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry + fileLinkRow.ValidTill < time.Microseconds() { c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"}) return } // checks password protected public collection - if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" { + if fileLinkRow.PassHash != nil && *fileLinkRow.PassHash != "" { reqPath := urlSanitizer(c) - if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil { + if err = m.validatePassword(c, reqPath, fileLinkRow); err != nil { logrus.WithError(err).Warn("password validation failed") c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err}) return @@ -99,14 +99,14 @@ func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) stri } if !cacheHit { - m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration) + m.Cache.Set(cacheKey, fileLinkRow, cache.DefaultExpiration) } - c.Set(auth.FileLinkAccessKey, &ente.PublicFileAccessContext{ - ID: publicCollectionSummary.LinkID, + c.Set(auth.FileLinkAccessKey, &ente.FileLinkAccessContext{ + LinkID: fileLinkRow.LinkID, IP: clientIP, UserAgent: userAgent, - FileID: publicCollectionSummary.FileID, + FileID: fileLinkRow.FileID, }) c.Next() } @@ -168,5 +168,5 @@ func (m *FileLinkMiddleware) validatePassword(c *gin.Context, reqPath string, if accessTokenJWT == "" { return ente.ErrAuthenticationRequired } - return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash) + return m.FileLinkCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash) } diff --git a/server/pkg/repo/public/public_collection.go b/server/pkg/repo/public/public_collection.go index 077d1889e8..5a73c2e43f 100644 --- a/server/pkg/repo/public/public_collection.go +++ b/server/pkg/repo/public/public_collection.go @@ -92,9 +92,9 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con return result, nil } -// GetActivePublicCollectionToken will return ente.CollectionLinkRow for given collection ID +// GetActiveCollectionLinkRow will return ente.CollectionLinkRow for given collection ID // Note: The token could be expired or deviceLimit is already reached -func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) { +func (pcr *PublicCollectionRepository) GetActiveCollectionLinkRow(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit, is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`, diff --git a/server/pkg/utils/auth/auth.go b/server/pkg/utils/auth/auth.go index e352d168f1..85acc995c3 100644 --- a/server/pkg/utils/auth/auth.go +++ b/server/pkg/utils/auth/auth.go @@ -133,6 +133,10 @@ func MustGetPublicAccessContext(c *gin.Context) ente.PublicAccessContext { return c.MustGet(PublicAccessKey).(ente.PublicAccessContext) } +func MustGetFileLinkAccessContext(c *gin.Context) *ente.FileLinkAccessContext { + return c.MustGet(FileLinkAccessKey).(*ente.FileLinkAccessContext) +} + func GetCastCtx(c *gin.Context) cast.AuthContext { return c.MustGet(CastContext).(cast.AuthContext) }