From 2e49f581c4814a38b679660107dc03355b1d8d1e Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Thu, 17 Jul 2025 14:10:37 +0530 Subject: [PATCH] File link token middleware --- server/cmd/museum/main.go | 4 +- server/ente/public_file.go | 10 +- server/pkg/controller/public_file.go | 2 +- .../{access_token.go => collection_token.go} | 14 +- server/pkg/middleware/file_link_token.go | 171 ++++++++++++++++++ server/pkg/repo/public/public_file.go | 24 ++- server/pkg/utils/auth/auth.go | 5 +- 7 files changed, 205 insertions(+), 25 deletions(-) rename server/pkg/middleware/{access_token.go => collection_token.go} (91%) create mode 100644 server/pkg/middleware/file_link_token.go diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index ad5b42cae3..90845fbba9 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -358,7 +358,7 @@ func main() { } authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController} - accessTokenMiddleware := middleware.AccessTokenMiddleware{ + collectionTokenMiddleware := middleware.CollectionTokenMiddleware{ PublicCollectionRepo: publicCollectionRepo, PublicCollectionCtrl: publicCollectionCtrl, CollectionRepo: collectionRepo, @@ -404,7 +404,7 @@ func main() { familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer)) publicCollectionAPI := server.Group("/public-collection") - publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer)) + publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionTokenMiddleware.Authenticate(urlSanitizer)) healthCheckHandler := &api.HealthCheckHandler{ DB: db, diff --git a/server/ente/public_file.go b/server/ente/public_file.go index e8462eaf9a..3b43ac0f43 100644 --- a/server/ente/public_file.go +++ b/server/ente/public_file.go @@ -19,7 +19,7 @@ type UpdateFileUrl struct { DisablePassword *bool `json:"disablePassword"` } -type PublicFileUrlRow struct { +type FileLinkRow struct { LinkID string OwnerID int64 FileID int64 @@ -52,8 +52,8 @@ type FileUrl struct { } type PublicFileAccessContext struct { - ID int64 - IP string - UserAgent string - CollectionID int64 + ID string + IP string + UserAgent string + FileID int64 } diff --git a/server/pkg/controller/public_file.go b/server/pkg/controller/public_file.go index b15d313c93..c74f1050df 100644 --- a/server/pkg/controller/public_file.go +++ b/server/pkg/controller/public_file.go @@ -36,7 +36,7 @@ func (c *PublicFileLinkController) CreateFileUrl(ctx *gin.Context, req ente.Crea return nil, stacktrace.Propagate(err, "failed to create public file link") } -func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.PublicFileUrlRow) *ente.FileUrl { +func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl { app := auth.GetApp(ctx) var url string if app == ente.Locker { diff --git a/server/pkg/middleware/access_token.go b/server/pkg/middleware/collection_token.go similarity index 91% rename from server/pkg/middleware/access_token.go rename to server/pkg/middleware/collection_token.go index 05ba354580..079f6f6f95 100644 --- a/server/pkg/middleware/access_token.go +++ b/server/pkg/middleware/collection_token.go @@ -25,8 +25,8 @@ import ( var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"} var whitelistedCollectionShareIDs = []int64{111} -// AccessTokenMiddleware intercepts and authenticates incoming requests -type AccessTokenMiddleware struct { +// CollectionTokenMiddleware intercepts and authenticates incoming requests +type CollectionTokenMiddleware struct { PublicCollectionRepo *public.PublicCollectionRepository PublicCollectionCtrl *controller.PublicCollectionController CollectionRepo *repo.CollectionRepository @@ -35,10 +35,10 @@ type AccessTokenMiddleware struct { DiscordController *discord.DiscordController } -// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token` +// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` // within the header of a request and uses it to validate the access token and set the // ente.PublicAccessContext with auth.PublicAccessKey as key -func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { +func (m *CollectionTokenMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { return func(c *gin.Context) { accessToken := auth.GetAccessToken(c) if accessToken == "" { @@ -113,7 +113,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g c.Next() } } -func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error { +func (m *CollectionTokenMiddleware) validateOwnersSubscription(cID int64) error { userID, err := m.CollectionRepo.GetOwnerID(cID) if err != nil { return stacktrace.Propagate(err, "") @@ -121,7 +121,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error { return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false) } -func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, +func (m *CollectionTokenMiddleware) isDeviceLimitReached(ctx context.Context, collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) { // skip deviceLimit check & record keeping for requests via CF worker if network.IsCFWorkerIP(ip) { @@ -163,7 +163,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, } // validatePassword will verify if the user is provided correct password for the public album -func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string, +func (m *CollectionTokenMiddleware) validatePassword(c *gin.Context, reqPath string, collectionSummary ente.PublicCollectionSummary) error { if array.StringInList(reqPath, passwordWhiteListedURLs) { return nil diff --git a/server/pkg/middleware/file_link_token.go b/server/pkg/middleware/file_link_token.go new file mode 100644 index 0000000000..9ff26d2533 --- /dev/null +++ b/server/pkg/middleware/file_link_token.go @@ -0,0 +1,171 @@ +package middleware + +import ( + "context" + "fmt" + "github.com/ente-io/museum/pkg/repo/public" + "net/http" + + "github.com/ente-io/museum/ente" + "github.com/ente-io/museum/pkg/controller" + "github.com/ente-io/museum/pkg/controller/discord" + "github.com/ente-io/museum/pkg/repo" + "github.com/ente-io/museum/pkg/utils/array" + "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/network" + "github.com/ente-io/museum/pkg/utils/time" + "github.com/ente-io/stacktrace" + "github.com/gin-gonic/gin" + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" +) + +var filePasswordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"} + +// FileLinkMiddleware intercepts and authenticates incoming requests +type FileLinkMiddleware struct { + FileLinkRepo *public.FileLinkRepository + PublicCollectionCtrl *controller.PublicCollectionController + CollectionRepo *repo.CollectionRepository + Cache *cache.Cache + BillingCtrl *controller.BillingController + DiscordController *discord.DiscordController +} + +// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` +// within the header of a request and uses it to validate the access token and set the +// ente.PublicAccessContext with auth.PublicAccessKey as key +func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { + return func(c *gin.Context) { + accessToken := auth.GetAccessToken(c) + if accessToken == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"}) + return + } + clientIP := network.GetClientIP(c) + userAgent := c.GetHeader("User-Agent") + + cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":") + cachedValue, cacheHit := m.Cache.Get(cacheKey) + var publicCollectionSummary *ente.FileLinkRow + var err error + if !cacheHit { + publicCollectionSummary, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + return + } + if publicCollectionSummary.IsDisabled { + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"}) + return + } + // validate if user still has active paid subscription + if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(publicCollectionSummary.OwnerID, true); err != nil { + logrus.WithError(err).Warn("failed to verify active paid subscription") + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"}) + return + } + + // validate device limit + reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent) + if err != nil { + logrus.WithError(err).Error("failed to check device limit") + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"}) + return + } + if reached { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"}) + return + } + } else { + publicCollectionSummary = cachedValue.(*ente.FileLinkRow) + } + + if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry + publicCollectionSummary.ValidTill < time.Microseconds() { + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"}) + return + } + + // checks password protected public collection + if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" { + reqPath := urlSanitizer(c) + if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil { + logrus.WithError(err).Warn("password validation failed") + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err}) + return + } + } + + if !cacheHit { + m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration) + } + + c.Set(auth.FileLinkAccessKey, &ente.PublicFileAccessContext{ + ID: publicCollectionSummary.LinkID, + IP: clientIP, + UserAgent: userAgent, + FileID: publicCollectionSummary.FileID, + }) + c.Next() + } +} +func (m *FileLinkMiddleware) validateOwnersSubscription(cID int64) error { + userID, err := m.CollectionRepo.GetOwnerID(cID) + if err != nil { + return stacktrace.Propagate(err, "") + } + return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) +} + +func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context, + collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) { + // skip deviceLimit check & record keeping for requests via CF worker + if network.IsCFWorkerIP(ip) { + return false, nil + } + + sharedID := collectionSummary.LinkID + hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua) + if err != nil { + return false, stacktrace.Propagate(err, "") + } + // if the device has accessed the url in the past, let it access it now as well, irrespective of device limit. + if hasAccessedInPast { + return false, nil + } + count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID) + if err != nil { + return false, stacktrace.Propagate(err, "failed to get unique access count") + } + + deviceLimit := int64(collectionSummary.DeviceLimit) + if deviceLimit == controller.DeviceLimitThreshold { + deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold + } + + if count >= controller.DeviceLimitWarningThreshold { + m.DiscordController.NotifyPotentialAbuse( + fmt.Sprintf("Album exceeds warning threshold: {FileID: %d, ShareID: %s}", + collectionSummary.FileID, collectionSummary.LinkID)) + } + + if deviceLimit > 0 && count >= deviceLimit { + return true, nil + } + err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua) + return false, stacktrace.Propagate(err, "failed to record access history") +} + +// validatePassword will verify if the user is provided correct password for the public album +func (m *FileLinkMiddleware) validatePassword(c *gin.Context, reqPath string, + fileLinkRow *ente.FileLinkRow) error { + if array.StringInList(reqPath, passwordWhiteListedURLs) { + return nil + } + accessTokenJWT := auth.GetAccessTokenJWT(c) + if accessTokenJWT == "" { + return ente.ErrAuthenticationRequired + } + return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash) +} diff --git a/server/pkg/repo/public/public_file.go b/server/pkg/repo/public/public_file.go index 9841e7edc2..022a6ba7b0 100644 --- a/server/pkg/repo/public/public_file.go +++ b/server/pkg/repo/public/public_file.go @@ -3,6 +3,7 @@ package public import ( "context" "database/sql" + "errors" "fmt" "github.com/ente-io/museum/ente/base" @@ -64,13 +65,13 @@ func (pcr *FileLinkRepository) Insert( // GetActiveFileUrlToken will return ente.PublicCollectionToken for given collection ID // Note: The token could be expired or deviceLimit is already reached -func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.PublicFileUrlRow, error) { +func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit, is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`, fileID) - ret := ente.PublicFileUrlRow{} + ret := ente.FileLinkRow{} err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit, &ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload) if err != nil { @@ -79,14 +80,14 @@ func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID return &ret, nil } -func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.PublicFileUrlRow, error) { +func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit created_at, updated_at from public_file_tokens where access_token = $1 `, accessToken) - var result = ente.PublicFileUrlRow{} + var result = ente.FileLinkRow{} err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt) if err != nil { if err == sql.ErrNoRows { @@ -98,7 +99,7 @@ func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessT } // UpdateLink will update the row for corresponding public file token -func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.PublicFileUrlRow) error { +func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error { _, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2, pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7 where id = $8`, @@ -116,7 +117,7 @@ func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId return count, nil } -func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error { +func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error { _, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history (id, ip, user_agent) VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`, @@ -125,8 +126,15 @@ func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID } // AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past -func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) { - panic("not implemented, refactor & public file") +func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) { + row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`, + shareID, ip, ua) + var tempID int64 + err := row.Scan(&tempID) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return true, stacktrace.Propagate(err, "failed to record access history") } // CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days diff --git a/server/pkg/utils/auth/auth.go b/server/pkg/utils/auth/auth.go index 6f8091998b..e352d168f1 100644 --- a/server/pkg/utils/auth/auth.go +++ b/server/pkg/utils/auth/auth.go @@ -17,8 +17,9 @@ import ( ) const ( - PublicAccessKey = "X-Public-Access-ID" - CastContext = "X-Cast-Context" + PublicAccessKey = "X-Public-Access-ID" + FileLinkAccessKey = "X-Public-FileLink-Access-ID" + CastContext = "X-Cast-Context" ) // GenerateRandomBytes returns securely generated random bytes.