File link token middleware
This commit is contained in:
@@ -358,7 +358,7 @@ func main() {
|
||||
}
|
||||
|
||||
authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController}
|
||||
accessTokenMiddleware := middleware.AccessTokenMiddleware{
|
||||
collectionTokenMiddleware := middleware.CollectionTokenMiddleware{
|
||||
PublicCollectionRepo: publicCollectionRepo,
|
||||
PublicCollectionCtrl: publicCollectionCtrl,
|
||||
CollectionRepo: collectionRepo,
|
||||
@@ -404,7 +404,7 @@ func main() {
|
||||
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
|
||||
publicCollectionAPI := server.Group("/public-collection")
|
||||
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
|
||||
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionTokenMiddleware.Authenticate(urlSanitizer))
|
||||
|
||||
healthCheckHandler := &api.HealthCheckHandler{
|
||||
DB: db,
|
||||
|
||||
@@ -19,7 +19,7 @@ type UpdateFileUrl struct {
|
||||
DisablePassword *bool `json:"disablePassword"`
|
||||
}
|
||||
|
||||
type PublicFileUrlRow struct {
|
||||
type FileLinkRow struct {
|
||||
LinkID string
|
||||
OwnerID int64
|
||||
FileID int64
|
||||
@@ -52,8 +52,8 @@ type FileUrl struct {
|
||||
}
|
||||
|
||||
type PublicFileAccessContext struct {
|
||||
ID int64
|
||||
IP string
|
||||
UserAgent string
|
||||
CollectionID int64
|
||||
ID string
|
||||
IP string
|
||||
UserAgent string
|
||||
FileID int64
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (c *PublicFileLinkController) CreateFileUrl(ctx *gin.Context, req ente.Crea
|
||||
return nil, stacktrace.Propagate(err, "failed to create public file link")
|
||||
}
|
||||
|
||||
func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.PublicFileUrlRow) *ente.FileUrl {
|
||||
func (c *PublicFileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl {
|
||||
app := auth.GetApp(ctx)
|
||||
var url string
|
||||
if app == ente.Locker {
|
||||
|
||||
@@ -25,8 +25,8 @@ import (
|
||||
var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
|
||||
var whitelistedCollectionShareIDs = []int64{111}
|
||||
|
||||
// AccessTokenMiddleware intercepts and authenticates incoming requests
|
||||
type AccessTokenMiddleware struct {
|
||||
// CollectionTokenMiddleware intercepts and authenticates incoming requests
|
||||
type CollectionTokenMiddleware struct {
|
||||
PublicCollectionRepo *public.PublicCollectionRepository
|
||||
PublicCollectionCtrl *controller.PublicCollectionController
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
@@ -35,10 +35,10 @@ type AccessTokenMiddleware struct {
|
||||
DiscordController *discord.DiscordController
|
||||
}
|
||||
|
||||
// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// within the header of a request and uses it to validate the access token and set the
|
||||
// ente.PublicAccessContext with auth.PublicAccessKey as key
|
||||
func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
func (m *CollectionTokenMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
accessToken := auth.GetAccessToken(c)
|
||||
if accessToken == "" {
|
||||
@@ -113,7 +113,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
func (m *CollectionTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
userID, err := m.CollectionRepo.GetOwnerID(cID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
@@ -121,7 +121,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
|
||||
}
|
||||
|
||||
func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
func (m *CollectionTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) {
|
||||
// skip deviceLimit check & record keeping for requests via CF worker
|
||||
if network.IsCFWorkerIP(ip) {
|
||||
@@ -163,7 +163,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
}
|
||||
|
||||
// validatePassword will verify if the user is provided correct password for the public album
|
||||
func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
|
||||
func (m *CollectionTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
|
||||
collectionSummary ente.PublicCollectionSummary) error {
|
||||
if array.StringInList(reqPath, passwordWhiteListedURLs) {
|
||||
return nil
|
||||
171
server/pkg/middleware/file_link_token.go
Normal file
171
server/pkg/middleware/file_link_token.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"net/http"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/controller/discord"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/network"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var filePasswordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
|
||||
|
||||
// FileLinkMiddleware intercepts and authenticates incoming requests
|
||||
type FileLinkMiddleware struct {
|
||||
FileLinkRepo *public.FileLinkRepository
|
||||
PublicCollectionCtrl *controller.PublicCollectionController
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
Cache *cache.Cache
|
||||
BillingCtrl *controller.BillingController
|
||||
DiscordController *discord.DiscordController
|
||||
}
|
||||
|
||||
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// within the header of a request and uses it to validate the access token and set the
|
||||
// ente.PublicAccessContext with auth.PublicAccessKey as key
|
||||
func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
accessToken := auth.GetAccessToken(c)
|
||||
if accessToken == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"})
|
||||
return
|
||||
}
|
||||
clientIP := network.GetClientIP(c)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
|
||||
cachedValue, cacheHit := m.Cache.Get(cacheKey)
|
||||
var publicCollectionSummary *ente.FileLinkRow
|
||||
var err error
|
||||
if !cacheHit {
|
||||
publicCollectionSummary, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
if publicCollectionSummary.IsDisabled {
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
|
||||
return
|
||||
}
|
||||
// validate if user still has active paid subscription
|
||||
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(publicCollectionSummary.OwnerID, true); err != nil {
|
||||
logrus.WithError(err).Warn("failed to verify active paid subscription")
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
// validate device limit
|
||||
reached, err := m.isDeviceLimitReached(c, publicCollectionSummary, clientIP, userAgent)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("failed to check device limit")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
|
||||
return
|
||||
}
|
||||
if reached {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
publicCollectionSummary = cachedValue.(*ente.FileLinkRow)
|
||||
}
|
||||
|
||||
if publicCollectionSummary.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
|
||||
publicCollectionSummary.ValidTill < time.Microseconds() {
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
|
||||
return
|
||||
}
|
||||
|
||||
// checks password protected public collection
|
||||
if publicCollectionSummary.PassHash != nil && *publicCollectionSummary.PassHash != "" {
|
||||
reqPath := urlSanitizer(c)
|
||||
if err = m.validatePassword(c, reqPath, publicCollectionSummary); err != nil {
|
||||
logrus.WithError(err).Warn("password validation failed")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !cacheHit {
|
||||
m.Cache.Set(cacheKey, publicCollectionSummary, cache.DefaultExpiration)
|
||||
}
|
||||
|
||||
c.Set(auth.FileLinkAccessKey, &ente.PublicFileAccessContext{
|
||||
ID: publicCollectionSummary.LinkID,
|
||||
IP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
FileID: publicCollectionSummary.FileID,
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
func (m *FileLinkMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
userID, err := m.CollectionRepo.GetOwnerID(cID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
|
||||
}
|
||||
|
||||
func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) {
|
||||
// skip deviceLimit check & record keeping for requests via CF worker
|
||||
if network.IsCFWorkerIP(ip) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
sharedID := collectionSummary.LinkID
|
||||
hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "")
|
||||
}
|
||||
// if the device has accessed the url in the past, let it access it now as well, irrespective of device limit.
|
||||
if hasAccessedInPast {
|
||||
return false, nil
|
||||
}
|
||||
count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "failed to get unique access count")
|
||||
}
|
||||
|
||||
deviceLimit := int64(collectionSummary.DeviceLimit)
|
||||
if deviceLimit == controller.DeviceLimitThreshold {
|
||||
deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold
|
||||
}
|
||||
|
||||
if count >= controller.DeviceLimitWarningThreshold {
|
||||
m.DiscordController.NotifyPotentialAbuse(
|
||||
fmt.Sprintf("Album exceeds warning threshold: {FileID: %d, ShareID: %s}",
|
||||
collectionSummary.FileID, collectionSummary.LinkID))
|
||||
}
|
||||
|
||||
if deviceLimit > 0 && count >= deviceLimit {
|
||||
return true, nil
|
||||
}
|
||||
err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
|
||||
return false, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// validatePassword will verify if the user is provided correct password for the public album
|
||||
func (m *FileLinkMiddleware) validatePassword(c *gin.Context, reqPath string,
|
||||
fileLinkRow *ente.FileLinkRow) error {
|
||||
if array.StringInList(reqPath, passwordWhiteListedURLs) {
|
||||
return nil
|
||||
}
|
||||
accessTokenJWT := auth.GetAccessTokenJWT(c)
|
||||
if accessTokenJWT == "" {
|
||||
return ente.ErrAuthenticationRequired
|
||||
}
|
||||
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package public
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/ente/base"
|
||||
|
||||
@@ -64,13 +65,13 @@ func (pcr *FileLinkRepository) Insert(
|
||||
|
||||
// GetActiveFileUrlToken will return ente.PublicCollectionToken for given collection ID
|
||||
// Note: The token could be expired or deviceLimit is already reached
|
||||
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.PublicFileUrlRow, error) {
|
||||
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit,
|
||||
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM
|
||||
public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`,
|
||||
fileID)
|
||||
|
||||
ret := ente.PublicFileUrlRow{}
|
||||
ret := ente.FileLinkRow{}
|
||||
err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
|
||||
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload)
|
||||
if err != nil {
|
||||
@@ -79,14 +80,14 @@ func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.PublicFileUrlRow, error) {
|
||||
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx,
|
||||
`SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit
|
||||
created_at, updated_at
|
||||
from public_file_tokens
|
||||
where access_token = $1
|
||||
`, accessToken)
|
||||
var result = ente.PublicFileUrlRow{}
|
||||
var result = ente.FileLinkRow{}
|
||||
err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -98,7 +99,7 @@ func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessT
|
||||
}
|
||||
|
||||
// UpdateLink will update the row for corresponding public file token
|
||||
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.PublicFileUrlRow) error {
|
||||
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2,
|
||||
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7
|
||||
where id = $8`,
|
||||
@@ -116,7 +117,7 @@ func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
|
||||
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history
|
||||
(id, ip, user_agent) VALUES ($1, $2, $3)
|
||||
ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`,
|
||||
@@ -125,8 +126,15 @@ func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID
|
||||
}
|
||||
|
||||
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
|
||||
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
|
||||
panic("not implemented, refactor & public file")
|
||||
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`,
|
||||
shareID, ip, ua)
|
||||
var tempID int64
|
||||
err := row.Scan(&tempID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return true, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicAccessKey = "X-Public-Access-ID"
|
||||
CastContext = "X-Cast-Context"
|
||||
PublicAccessKey = "X-Public-Access-ID"
|
||||
FileLinkAccessKey = "X-Public-FileLink-Access-ID"
|
||||
CastContext = "X-Cast-Context"
|
||||
)
|
||||
|
||||
// GenerateRandomBytes returns securely generated random bytes.
|
||||
|
||||
Reference in New Issue
Block a user