File link token middleware

This commit is contained in:
Neeraj Gupta
2025-07-17 14:10:37 +05:30
parent c5d9b2408f
commit 2e49f581c4
7 changed files with 205 additions and 25 deletions

View File

@@ -358,7 +358,7 @@ func main() {
}
authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController}
accessTokenMiddleware := middleware.AccessTokenMiddleware{
collectionTokenMiddleware := middleware.CollectionTokenMiddleware{
PublicCollectionRepo: publicCollectionRepo,
PublicCollectionCtrl: publicCollectionCtrl,
CollectionRepo: collectionRepo,
@@ -404,7 +404,7 @@ func main() {
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
publicCollectionAPI := server.Group("/public-collection")
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionTokenMiddleware.Authenticate(urlSanitizer))
healthCheckHandler := &api.HealthCheckHandler{
DB: db,

View File

@@ -19,7 +19,7 @@ type UpdateFileUrl struct {
DisablePassword *bool `json:"disablePassword"`
}
type PublicFileUrlRow struct {
type FileLinkRow struct {
LinkID string
OwnerID int64
FileID int64
@@ -52,8 +52,8 @@ type FileUrl struct {
}
type PublicFileAccessContext struct {
ID int64
IP string
UserAgent string
CollectionID int64
ID string
IP string
UserAgent string
FileID int64
}

View File

@@ -36,7 +36,7 @@ func (c *PublicFileLinkController) CreateFileUrl(ctx *gin.Context, req ente.Crea
return nil, stacktrace.Propagate(err, "failed to create public file link")
}
func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.PublicFileUrlRow) *ente.FileUrl {
func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl {
app := auth.GetApp(ctx)
var url string
if app == ente.Locker {

View File

@@ -25,8 +25,8 @@ import (
var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
var whitelistedCollectionShareIDs = []int64{111}
// AccessTokenMiddleware intercepts and authenticates incoming requests
type AccessTokenMiddleware struct {
// CollectionTokenMiddleware intercepts and authenticates incoming requests
type CollectionTokenMiddleware struct {
PublicCollectionRepo *public.PublicCollectionRepository
PublicCollectionCtrl *controller.PublicCollectionController
CollectionRepo *repo.CollectionRepository
@@ -35,10 +35,10 @@ type AccessTokenMiddleware struct {
DiscordController *discord.DiscordController
}
// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token`
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
// within the header of a request and uses it to validate the access token and set the
// ente.PublicAccessContext with auth.PublicAccessKey as key
func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
func (m *CollectionTokenMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
accessToken := auth.GetAccessToken(c)
if accessToken == "" {
@@ -113,7 +113,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
c.Next()
}
}
func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
func (m *CollectionTokenMiddleware) validateOwnersSubscription(cID int64) error {
userID, err := m.CollectionRepo.GetOwnerID(cID)
if err != nil {
return stacktrace.Propagate(err, "")
@@ -121,7 +121,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
}
func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
func (m *CollectionTokenMiddleware) isDeviceLimitReached(ctx context.Context,
collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) {
// skip deviceLimit check & record keeping for requests via CF worker
if network.IsCFWorkerIP(ip) {
@@ -163,7 +163,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
}
// validatePassword will verify if the user is provided correct password for the public album
func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
func (m *CollectionTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
collectionSummary ente.PublicCollectionSummary) error {
if array.StringInList(reqPath, passwordWhiteListedURLs) {
return nil

View File

@@ -0,0 +1,171 @@
package middleware
import (
"context"
"fmt"
"github.com/ente-io/museum/pkg/repo/public"
"net/http"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/controller/discord"
"github.com/ente-io/museum/pkg/repo"
"github.com/ente-io/museum/pkg/utils/array"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/museum/pkg/utils/network"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/sirupsen/logrus"
)
var filePasswordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
// FileLinkMiddleware intercepts and authenticates incoming requests
type FileLinkMiddleware struct {
FileLinkRepo *public.FileLinkRepository
PublicCollectionCtrl *controller.PublicCollectionController
CollectionRepo *repo.CollectionRepository
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
}
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
// within the header of a request and uses it to validate the access token and set the
// ente.PublicAccessContext with auth.PublicAccessKey as key
func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
accessToken := auth.GetAccessToken(c)
if accessToken == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"})
return
}
clientIP := network.GetClientIP(c)
userAgent := c.GetHeader("User-Agent")
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
cachedValue, cacheHit := m.Cache.Get(cacheKey)
var publicCollectionSummary *ente.FileLinkRow
var err error
if !cacheHit {
publicCollectionSummary, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if publicCollectionSummary.IsDisabled {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
return
}
// validate if user still has active paid subscription
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(publicCollectionSummary.OwnerID, true); err != nil {
logrus.WithError(err).Warn("failed to verify active paid subscription")
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
return
}
// validate device limit
reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent)
if err != nil {
logrus.WithError(err).Error("failed to check device limit")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
return
}
if reached {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"})
return
}
} else {
publicCollectionSummary = cachedValue.(*ente.FileLinkRow)
}
if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
publicCollectionSummary.ValidTill < time.Microseconds() {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
return
}
// checks password protected public collection
if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" {
reqPath := urlSanitizer(c)
if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil {
logrus.WithError(err).Warn("password validation failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
return
}
}
if !cacheHit {
m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration)
}
c.Set(auth.FileLinkAccessKey, &ente.PublicFileAccessContext{
ID: publicCollectionSummary.LinkID,
IP: clientIP,
UserAgent: userAgent,
FileID: publicCollectionSummary.FileID,
})
c.Next()
}
}
func (m *FileLinkMiddleware) validateOwnersSubscription(cID int64) error {
userID, err := m.CollectionRepo.GetOwnerID(cID)
if err != nil {
return stacktrace.Propagate(err, "")
}
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
}
func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context,
collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) {
// skip deviceLimit check & record keeping for requests via CF worker
if network.IsCFWorkerIP(ip) {
return false, nil
}
sharedID := collectionSummary.LinkID
hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
if err != nil {
return false, stacktrace.Propagate(err, "")
}
// if the device has accessed the url in the past, let it access it now as well, irrespective of device limit.
if hasAccessedInPast {
return false, nil
}
count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID)
if err != nil {
return false, stacktrace.Propagate(err, "failed to get unique access count")
}
deviceLimit := int64(collectionSummary.DeviceLimit)
if deviceLimit == controller.DeviceLimitThreshold {
deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold
}
if count >= controller.DeviceLimitWarningThreshold {
m.DiscordController.NotifyPotentialAbuse(
fmt.Sprintf("Album exceeds warning threshold: {FileID: %d, ShareID: %s}",
collectionSummary.FileID, collectionSummary.LinkID))
}
if deviceLimit > 0 && count >= deviceLimit {
return true, nil
}
err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
return false, stacktrace.Propagate(err, "failed to record access history")
}
// validatePassword will verify if the user is provided correct password for the public album
func (m *FileLinkMiddleware) validatePassword(c *gin.Context, reqPath string,
fileLinkRow *ente.FileLinkRow) error {
if array.StringInList(reqPath, passwordWhiteListedURLs) {
return nil
}
accessTokenJWT := auth.GetAccessTokenJWT(c)
if accessTokenJWT == "" {
return ente.ErrAuthenticationRequired
}
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
}

View File

@@ -3,6 +3,7 @@ package public
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ente-io/museum/ente/base"
@@ -64,13 +65,13 @@ func (pcr *FileLinkRepository) Insert(
// GetActiveFileUrlToken will return ente.PublicCollectionToken for given collection ID
// Note: The token could be expired or deviceLimit is already reached
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.PublicFileUrlRow, error) {
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit,
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM
public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`,
fileID)
ret := ente.PublicFileUrlRow{}
ret := ente.FileLinkRow{}
err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload)
if err != nil {
@@ -79,14 +80,14 @@ func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID
return &ret, nil
}
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.PublicFileUrlRow, error) {
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx,
`SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit
created_at, updated_at
from public_file_tokens
where access_token = $1
`, accessToken)
var result = ente.PublicFileUrlRow{}
var result = ente.FileLinkRow{}
err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
@@ -98,7 +99,7 @@ func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessT
}
// UpdateLink will update the row for corresponding public file token
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.PublicFileUrlRow) error {
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error {
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2,
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7
where id = $8`,
@@ -116,7 +117,7 @@ func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId
return count, nil
}
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error {
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history
(id, ip, user_agent) VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`,
@@ -125,8 +126,15 @@ func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID
}
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
panic("not implemented, refactor & public file")
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) {
row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`,
shareID, ip, ua)
var tempID int64
err := row.Scan(&tempID)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return true, stacktrace.Propagate(err, "failed to record access history")
}
// CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days

View File

@@ -17,8 +17,9 @@ import (
)
const (
PublicAccessKey = "X-Public-Access-ID"
CastContext = "X-Cast-Context"
PublicAccessKey = "X-Public-Access-ID"
FileLinkAccessKey = "X-Public-FileLink-Access-ID"
CastContext = "X-Cast-Context"
)
// GenerateRandomBytes returns securely generated random bytes.