diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 5e5daef7ed..87d720530a 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -5,6 +5,9 @@ import ( "database/sql" b64 "encoding/base64" "fmt" + "github.com/ente-io/museum/pkg/controller/collections" + publicCtrl "github.com/ente-io/museum/pkg/controller/public" + "github.com/ente-io/museum/pkg/repo/public" "net/http" "os" "os/signal" @@ -14,8 +17,6 @@ import ( "syscall" "time" - "github.com/ente-io/museum/pkg/controller/collections" - "github.com/ente-io/museum/ente/base" "github.com/ente-io/museum/pkg/controller/emergency" "github.com/ente-io/museum/pkg/controller/file_copy" @@ -97,6 +98,7 @@ func main() { } viper.SetDefault("apps.public-albums", "https://albums.ente.io") + viper.SetDefault("apps.public-locker", "https://locker.ente.io") viper.SetDefault("apps.accounts", "https://accounts.ente.io") viper.SetDefault("apps.cast", "https://cast.ente.io") viper.SetDefault("apps.family", "https://family.ente.io") @@ -174,11 +176,13 @@ func main() { fileRepo := &repo.FileRepository{DB: db, S3Config: s3Config, QueueRepo: queueRepo, ObjectRepo: objectRepo, ObjectCleanupRepo: objectCleanupRepo, ObjectCopiesRepo: objectCopiesRepo, UsageRepo: usageRepo} + fileLinkRepo := public.NewFileLinkRepo(db) fileDataRepo := &fileDataRepo.Repository{DB: db, ObjectCleanupRepo: objectCleanupRepo} familyRepo := &repo.FamilyRepository{DB: db} - trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo} - publicCollectionRepo := repo.NewPublicCollectionRepository(db, viper.GetString("apps.public-albums")) - collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, PublicCollectionRepo: publicCollectionRepo, + trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo, FileLinkRepo: fileLinkRepo} + collectionLinkRepo := public.NewCollectionLinkRepository(db, viper.GetString("apps.public-albums")) + + collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, CollectionLinkRepo: collectionLinkRepo, TrashRepo: trashRepo, SecretEncryptionKey: secretEncryptionKeyBytes, QueueRepo: queueRepo, LatencyLogger: latencyLogger} pushRepo := &repo.PushTokenRepository{DB: db} kexRepo := &kex.Repository{ @@ -300,26 +304,27 @@ func main() { UsageRepo: usageRepo, } - publicCollectionCtrl := &controller.PublicCollectionController{ + collectionLinkCtrl := &publicCtrl.CollectionLinkController{ FileController: fileController, EmailNotificationCtrl: emailNotificationCtrl, - PublicCollectionRepo: publicCollectionRepo, + CollectionLinkRepo: collectionLinkRepo, + FileLinkRepo: fileLinkRepo, CollectionRepo: collectionRepo, UserRepo: userRepo, JwtSecret: jwtSecretBytes, } collectionController := &collections.CollectionController{ - CollectionRepo: collectionRepo, - EmailCtrl: emailNotificationCtrl, - AccessCtrl: accessCtrl, - PublicCollectionCtrl: publicCollectionCtrl, - UserRepo: userRepo, - FileRepo: fileRepo, - CastRepo: &castDb, - BillingCtrl: billingController, - QueueRepo: queueRepo, - TaskRepo: taskLockingRepo, + CollectionRepo: collectionRepo, + EmailCtrl: emailNotificationCtrl, + AccessCtrl: accessCtrl, + CollectionLinkCtrl: collectionLinkCtrl, + UserRepo: userRepo, + FileRepo: fileRepo, + CastRepo: &castDb, + BillingCtrl: billingController, + QueueRepo: queueRepo, + TaskRepo: taskLockingRepo, } kexCtrl := &kexCtrl.Controller{ @@ -351,6 +356,12 @@ func main() { userCache, userCacheCtrl, ) + fileLinkCtrl := &publicCtrl.FileLinkController{ + FileController: fileController, + FileLinkRepo: fileLinkRepo, + FileRepo: fileRepo, + JwtSecret: jwtSecretBytes, + } passkeyCtrl := &controller.PasskeyController{ Repo: passkeysRepo, @@ -358,14 +369,21 @@ func main() { } authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController} - accessTokenMiddleware := middleware.AccessTokenMiddleware{ - PublicCollectionRepo: publicCollectionRepo, - PublicCollectionCtrl: publicCollectionCtrl, + collectionLinkMiddleware := middleware.CollectionLinkMiddleware{ + CollectionLinkRepo: collectionLinkRepo, + PublicCollectionCtrl: collectionLinkCtrl, CollectionRepo: collectionRepo, Cache: accessTokenCache, BillingCtrl: billingController, DiscordController: discordController, } + fileLinkMiddleware := &middleware.FileLinkMiddleware{ + FileLinkRepo: fileLinkRepo, + FileLinkCtrl: fileLinkCtrl, + Cache: accessTokenCache, + BillingCtrl: billingController, + DiscordController: discordController, + } if environment != "local" { gin.SetMode(gin.ReleaseMode) @@ -404,7 +422,9 @@ func main() { familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer)) publicCollectionAPI := server.Group("/public-collection") - publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer)) + publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionLinkMiddleware.Authenticate(urlSanitizer)) + fileLinkApi := server.GET("/file-link") + fileLinkApi.Use(rateLimiter.GlobalRateLimiter(), fileLinkMiddleware.Authenticate(urlSanitizer)) healthCheckHandler := &api.HealthCheckHandler{ DB: db, @@ -432,6 +452,7 @@ func main() { Controller: fileController, FileCopyCtrl: fileCopyCtrl, FileDataCtrl: fileDataCtrl, + FileUrlCtrl: fileLinkCtrl, } privateAPI.GET("/files/upload-urls", fileHandler.GetUploadURLs) privateAPI.GET("/files/multipart-upload-urls", fileHandler.GetMultipartUploadURLs) @@ -440,6 +461,11 @@ func main() { privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail) privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail) + privateAPI.POST("/files/share-url", fileHandler.ShareUrl) + privateAPI.PUT("/files/share-url", fileHandler.UpdateFileURL) + privateAPI.DELETE("/files/share-url/:fileID", fileHandler.DisableUrl) + privateAPI.GET("/files/share-urls/", fileHandler.GetUrls) + privateAPI.PUT("/files/data", fileHandler.PutFileData) privateAPI.PUT("/files/video-data", fileHandler.PutVideoData) privateAPI.POST("/files/data/status-diff", fileHandler.FileDataStatusDiff) @@ -566,13 +592,19 @@ func main() { privateAPI.PUT("/collections/sharee-magic-metadata", collectionHandler.ShareeMagicMetadataUpdate) publicCollectionHandler := &api.PublicCollectionHandler{ - Controller: publicCollectionCtrl, + Controller: collectionLinkCtrl, FileCtrl: fileController, CollectionCtrl: collectionController, FileDataCtrl: fileDataCtrl, StorageBonusController: storageBonusCtrl, } + fileLinkApi.GET("/info", fileHandler.LinkInfo) + fileLinkApi.GET("/pass-info", fileHandler.PasswordInfo) + fileLinkApi.GET("/thumbnail", fileHandler.LinkThumbnail) + fileLinkApi.GET("/file", fileHandler.LinkFile) + fileLinkApi.POST("/verify-password", fileHandler.VerifyPassword) + publicCollectionAPI.GET("/files/preview/:fileID", publicCollectionHandler.GetThumbnail) publicCollectionAPI.GET("/files/download/:fileID", publicCollectionHandler.GetFile) publicCollectionAPI.GET("/files/data/fetch", publicCollectionHandler.GetFileData) @@ -770,7 +802,7 @@ func main() { setKnownAPIs(server.Routes()) setupAndStartBackgroundJobs(objectCleanupController, replicationController3, fileDataCtrl) setupAndStartCrons( - userAuthRepo, publicCollectionRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl, + userAuthRepo, collectionLinkRepo, fileLinkRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl, trashController, pushController, objectController, dataCleanupController, storageBonusCtrl, emergencyCtrl, embeddingController, healthCheckHandler, kexCtrl, castDb) @@ -899,7 +931,8 @@ func setupAndStartBackgroundJobs( objectCleanupController.StartClearingOrphanObjects() } -func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionRepo *repo.PublicCollectionRepository, +func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, collectionLinkRepo *public.CollectionLinkRepo, + fileLinkRepo *public.FileLinkRepository, twoFactorRepo *repo.TwoFactorRepository, passkeysRepo *passkey.Repository, fileController *controller.FileController, taskRepo *repo.TaskLockRepository, emailNotificationCtrl *email.EmailNotificationController, trashController *controller.TrashController, pushController *controller.PushController, @@ -925,7 +958,8 @@ func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionR schedule(c, "@every 24h", func() { _ = userAuthRepo.RemoveDeletedTokens(timeUtil.MicrosecondsBeforeDays(30)) _ = castDb.DeleteOldSessions(context.Background(), timeUtil.MicrosecondsBeforeDays(7)) - _ = publicCollectionRepo.CleanupAccessHistory(context.Background()) + _ = collectionLinkRepo.CleanupAccessHistory(context.Background()) + _ = fileLinkRepo.CleanupAccessHistory(context.Background()) }) schedule(c, "@every 1m", func() { diff --git a/server/configurations/local.yaml b/server/configurations/local.yaml index b6b1d567eb..a16f560e43 100644 --- a/server/configurations/local.yaml +++ b/server/configurations/local.yaml @@ -79,9 +79,14 @@ http: apps: # Default is https://albums.ente.io # - # If you're running a self hosted instance and wish to serve public links, + # If you're running a self hosted instance and wish to serve public links for photos, # set this to the URL where your albums web app is running. public-albums: + # Default is https://locker.ente.io + # + # If you're running a self-hosted instance and wish to serve public links for locker, + # set this to the URL where your albums web app is running. + public-locker: # Default is https://cast.ente.io cast: # Default is https://accounts.ente.io diff --git a/server/ente/errors.go b/server/ente/errors.go index 4a5a02fb47..2370ab7fe8 100644 --- a/server/ente/errors.go +++ b/server/ente/errors.go @@ -97,8 +97,8 @@ var ErrUserDeleted = errors.New("user account has been deleted") // ErrLockUnavailable is thrown when a lock could not be acquired var ErrLockUnavailable = errors.New("could not acquire lock") -// ErrActiveLinkAlreadyExists is thrown when the collection already has active public link -var ErrActiveLinkAlreadyExists = errors.New("Collection already has active public link") +// ErrActiveLinkAlreadyExists is thrown when an active link already exists for entity +var ErrActiveLinkAlreadyExists = errors.New("link already exists for this entity") // ErrNotImplemented indicates that the action that we tried to perform is not // available at this museum instance. e.g. this could be something that is not @@ -176,6 +176,11 @@ var ErrMaxPasskeysReached = ApiError{ Message: "Max passkeys limit reached", HttpStatusCode: http.StatusConflict, } +var ErrPassProtectedResource = ApiError{ + Code: "PASS_PROTECTED_RESOURCE", + Message: "This resource is password protected", + HttpStatusCode: http.StatusForbidden, +} var ErrCastPermissionDenied = ApiError{ Code: "CAST_PERMISSION_DENIED", diff --git a/server/ente/file_link.go b/server/ente/file_link.go new file mode 100644 index 0000000000..e817e35da0 --- /dev/null +++ b/server/ente/file_link.go @@ -0,0 +1,94 @@ +package ente + +import ( + "fmt" + "github.com/ente-io/museum/pkg/utils/time" +) + +// CreateFileUrl represents an encrypted file in the system +type CreateFileUrl struct { + FileID int64 `json:"fileID" binding:"required"` + App App `json:"app" binding:"required"` +} + +// UpdateFileUrl .. +type UpdateFileUrl struct { + LinkID string `json:"linkID" binding:"required"` + FileID int64 `json:"fileID" binding:"required"` + ValidTill *int64 `json:"validTill"` + DeviceLimit *int `json:"deviceLimit"` + PassHash *string + Nonce *string + MemLimit *int64 + OpsLimit *int64 + EnableDownload *bool `json:"enableDownload"` + DisablePassword *bool `json:"disablePassword"` +} + +func (ut *UpdateFileUrl) Validate() error { + if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil && + ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil { + return NewBadRequestWithMessage("all parameters are missing") + } + + if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) { + return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit)) + } + + if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() { + return NewBadRequestWithMessage("valid till should be greater than current timestamp") + } + + var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil + var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil + + if !(allPassParamsMissing || allPassParamsPresent) { + return NewBadRequestWithMessage("all password params should be either present or missing") + } + + if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword { + return NewBadRequestWithMessage("can not set and disable password in same request") + } + return nil +} + +type FileLinkRow struct { + LinkID string + OwnerID int64 + FileID int64 + Token string + DeviceLimit int + ValidTill int64 + IsDisabled bool + PassHash *string + Nonce *string + MemLimit *int64 + OpsLimit *int64 + EnableDownload bool + CreatedAt int64 + UpdatedAt int64 +} + +type FileUrl struct { + LinkID string `json:"linkID" binding:"required"` + URL string `json:"url" binding:"required"` + OwnerID int64 `json:"ownerID" binding:"required"` + FileID int64 `json:"fileID" binding:"required"` + ValidTill int64 `json:"validTill"` + DeviceLimit int `json:"deviceLimit"` + PasswordEnabled bool `json:"passwordEnabled"` + // Nonce contains the nonce value for the password if the link is password protected. + Nonce *string `json:"nonce,omitempty"` + MemLimit *int64 `json:"memLimit,omitempty"` + OpsLimit *int64 `json:"opsLimit,omitempty"` + EnableDownload bool `json:"enableDownload"` + CreatedAt int64 `json:"createdAt"` +} + +type FileLinkAccessContext struct { + LinkID string + IP string + UserAgent string + FileID int64 + OwnerID int64 +} diff --git a/server/ente/jwt/jwt.go b/server/ente/jwt/jwt.go index 94cfa995f2..c4d210b66c 100644 --- a/server/ente/jwt/jwt.go +++ b/server/ente/jwt/jwt.go @@ -40,13 +40,13 @@ func (w WebCommonJWTClaim) Valid() error { return nil } -// PublicAlbumPasswordClaim refer to token granted post public album password verification -type PublicAlbumPasswordClaim struct { +// LinkPasswordClaim refer to token granted post link password verification +type LinkPasswordClaim struct { PassHash string `json:"passKey"` ExpiryTime int64 `json:"expiryTime"` } -func (c PublicAlbumPasswordClaim) Valid() error { +func (c LinkPasswordClaim) Valid() error { if c.ExpiryTime < time.Microseconds() { return errors.New("token expired") } diff --git a/server/ente/public_collection.go b/server/ente/public_collection.go index eb1bd8c385..5f3867e6d0 100644 --- a/server/ente/public_collection.go +++ b/server/ente/public_collection.go @@ -3,7 +3,7 @@ package ente import ( "database/sql/driver" "encoding/json" - + "fmt" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" ) @@ -32,6 +32,33 @@ type UpdatePublicAccessTokenRequest struct { EnableJoin *bool `json:"enableJoin"` } +func (ut *UpdatePublicAccessTokenRequest) Validate() error { + if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil && + ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil && ut.EnableCollect == nil { + return NewBadRequestWithMessage("all parameters are missing") + } + + if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) { + return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit)) + } + + if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() { + return NewBadRequestWithMessage("valid till should be greater than current timestamp") + } + + var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil + var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil + + if !(allPassParamsMissing || allPassParamsPresent) { + return NewBadRequestWithMessage("all password params should be either present or missing") + } + + if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword { + return NewBadRequestWithMessage("can not set and disable password in same request") + } + return nil +} + type VerifyPasswordRequest struct { PassHash string `json:"passHash" binding:"required"` } @@ -40,8 +67,8 @@ type VerifyPasswordResponse struct { JWTToken string `json:"jwtToken"` } -// PublicCollectionToken represents row entity for public_collection_token table -type PublicCollectionToken struct { +// CollectionLinkRow represents row entity for public_collection_token table +type CollectionLinkRow struct { ID int64 CollectionID int64 Token string @@ -57,7 +84,7 @@ type PublicCollectionToken struct { EnableJoin bool } -func (p PublicCollectionToken) CanJoin() error { +func (p CollectionLinkRow) CanJoin() error { if p.IsDisabled { return NewBadRequestWithMessage("link disabled") } diff --git a/server/migrations/103_single_file_url.down.sql b/server/migrations/103_single_file_url.down.sql new file mode 100644 index 0000000000..2efd3e0053 --- /dev/null +++ b/server/migrations/103_single_file_url.down.sql @@ -0,0 +1,3 @@ + +DROP TABLE IF EXISTS public_file_tokens_access_history; +DROP TABLE IF EXISTS public_file_tokens; diff --git a/server/migrations/103_single_file_url.up.sql b/server/migrations/103_single_file_url.up.sql new file mode 100644 index 0000000000..3145e46aad --- /dev/null +++ b/server/migrations/103_single_file_url.up.sql @@ -0,0 +1,46 @@ + + +CREATE TABLE IF NOT EXISTS public_file_tokens +( + id text primary key, + file_id bigint NOT NULL, + owner_id bigint NOT NULL, + app text NOT NULL, + access_token text not null, + valid_till bigint not null DEFAULT 0, + device_limit int not null DEFAULT 0, + is_disabled bool not null DEFAULT FALSE, + enable_download bool not null DEFAULT TRUE, + pw_hash TEXT, + pw_nonce TEXT, + mem_limit BIGINT, + ops_limit BIGINT, + created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(), + updated_at bigint NOT NULL DEFAULT now_utc_micro_seconds() +); + + +CREATE OR REPLACE TRIGGER update_public_file_tokens_updated_at + BEFORE UPDATE + ON public_file_tokens + FOR EACH ROW +EXECUTE PROCEDURE + trigger_updated_at_microseconds_column(); + + +CREATE TABLE IF NOT EXISTS public_file_tokens_access_history +( + id text NOT NULL, + ip text not null, + user_agent text not null, + created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(), + CONSTRAINT unique_access_id_ip_ua UNIQUE (id, ip, user_agent), + CONSTRAINT fk_public_file_history_token_id + FOREIGN KEY (id) + REFERENCES public_file_tokens (id) + ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS public_file_token_unique_idx ON public_file_tokens (access_token) WHERE is_disabled = FALSE; +CREATE INDEX IF NOT EXISTS public_file_tokens_owner_id_updated_at_idx ON public_file_tokens (owner_id, updated_at); +CREATE UNIQUE INDEX IF NOT EXISTS public_active_file_link_unique_idx ON public_file_tokens (file_id, is_disabled) WHERE is_disabled = FALSE; diff --git a/server/pkg/api/collection.go b/server/pkg/api/collection.go index 9318f5c329..6c198652c7 100644 --- a/server/pkg/api/collection.go +++ b/server/pkg/api/collection.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/ente-io/museum/ente" - "github.com/ente-io/museum/pkg/controller" "github.com/ente-io/museum/pkg/utils/auth" "github.com/ente-io/museum/pkg/utils/handler" "github.com/ente-io/museum/pkg/utils/time" @@ -172,35 +171,6 @@ func (h *CollectionHandler) UpdateShareURL(c *gin.Context) { handler.Error(c, stacktrace.Propagate(err, "")) return } - if req.DeviceLimit == nil && req.ValidTill == nil && req.DisablePassword == nil && - req.Nonce == nil && req.PassHash == nil && req.EnableDownload == nil && req.EnableCollect == nil { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all parameters are missing")) - return - } - - if req.DeviceLimit != nil && (*req.DeviceLimit < 0 || *req.DeviceLimit > controller.DeviceLimitThreshold) { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("device limit: %d out of range", *req.DeviceLimit))) - return - } - - if req.ValidTill != nil && *req.ValidTill != 0 && *req.ValidTill < time.Microseconds() { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "valid till should be greater than current timestamp")) - return - } - - var allPassParamsMissing = req.Nonce == nil && req.PassHash == nil && req.MemLimit == nil && req.OpsLimit == nil - var allPassParamsPresent = req.Nonce != nil && req.PassHash != nil && req.MemLimit != nil && req.OpsLimit != nil - - if !(allPassParamsMissing || allPassParamsPresent) { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all password params should be either present or missing")) - return - } - - if allPassParamsPresent && req.DisablePassword != nil && *req.DisablePassword { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "can not set and disable password in same request")) - return - } - response, err := h.Controller.UpdateShareURL(c, auth.GetUserID(c.Request.Header), req) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) diff --git a/server/pkg/api/file.go b/server/pkg/api/file.go index 2e15ade325..4ec205d1bb 100644 --- a/server/pkg/api/file.go +++ b/server/pkg/api/file.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/ente-io/museum/pkg/controller/file_copy" "github.com/ente-io/museum/pkg/controller/filedata" + "github.com/ente-io/museum/pkg/controller/public" "net/http" "os" "strconv" @@ -24,6 +25,7 @@ import ( // FileHandler exposes request handlers for all encrypted file related requests type FileHandler struct { Controller *controller.FileController + FileUrlCtrl *public.FileLinkController FileCopyCtrl *file_copy.FileCopyController FileDataCtrl *filedata.Controller } diff --git a/server/pkg/api/file_link.go b/server/pkg/api/file_link.go new file mode 100644 index 0000000000..d243a87532 --- /dev/null +++ b/server/pkg/api/file_link.go @@ -0,0 +1,141 @@ +package api + +import ( + "github.com/ente-io/museum/ente" + "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/handler" + "github.com/ente-io/stacktrace" + "github.com/gin-gonic/gin" + "net/http" + "strconv" +) + +// ShareUrl a sharable url for the file +func (h *FileHandler) ShareUrl(c *gin.Context) { + var file ente.CreateFileUrl + if err := c.ShouldBindJSON(&file); err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + + response, err := h.FileUrlCtrl.CreateLink(c, file) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, response) +} + +func (h *FileHandler) LinkInfo(c *gin.Context) { + resp, err := h.FileUrlCtrl.Info(c) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "file": resp, + }) +} + +func (h *FileHandler) PasswordInfo(c *gin.Context) { + resp, err := h.FileUrlCtrl.PassInfo(c) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "nonce": resp.Nonce, + "opsLimit": resp.OpsLimit, + "memLimit": resp.MemLimit, + }) +} + +func (h *FileHandler) LinkThumbnail(c *gin.Context) { + linkCtx := auth.MustGetFileLinkAccessContext(c) + url, err := h.Controller.GetThumbnailURL(c, linkCtx.OwnerID, linkCtx.FileID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.Redirect(http.StatusTemporaryRedirect, url) +} + +func (h *FileHandler) LinkFile(c *gin.Context) { + linkCtx := auth.MustGetFileLinkAccessContext(c) + url, err := h.Controller.GetFileURL(c, linkCtx.OwnerID, linkCtx.FileID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.Redirect(http.StatusTemporaryRedirect, url) +} + +func (h *FileHandler) DisableUrl(c *gin.Context) { + cID, err := strconv.ParseInt(c.Param("fileID"), 10, 64) + if err != nil { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "")) + return + } + err = h.FileUrlCtrl.Disable(c, cID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} + +func (h *FileHandler) GetUrls(c *gin.Context) { + sinceTime, err := strconv.ParseInt(c.Query("sinceTime"), 10, 64) + if err != nil { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "sinceTime parsing failed")) + return + } + limit := 500 + if c.Query("limit") != "" { + limit, err = strconv.Atoi(c.Query("limit")) + if err != nil || limit < 1 { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "")) + return + } + } + response, err := h.FileUrlCtrl.GetUrls(c, sinceTime, int64(limit)) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "diff": response, + }) +} + +// VerifyPassword verifies the password for given link access token and return signed jwt token if it's valid +func (h *FileHandler) VerifyPassword(c *gin.Context) { + var req ente.VerifyPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + resp, err := h.FileUrlCtrl.VerifyPassword(c, req) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, resp) +} + +// UpdateFileURL updates the share URL for a file +func (h *FileHandler) UpdateFileURL(c *gin.Context) { + var req ente.UpdateFileUrl + if err := c.ShouldBindJSON(&req); err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + response, err := h.FileUrlCtrl.UpdateSharedUrl(c, req) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "result": response, + }) +} diff --git a/server/pkg/api/public_collection.go b/server/pkg/api/public_collection.go index 9f61ba788e..81e1836f90 100644 --- a/server/pkg/api/public_collection.go +++ b/server/pkg/api/public_collection.go @@ -5,6 +5,7 @@ import ( fileData "github.com/ente-io/museum/ente/filedata" "github.com/ente-io/museum/pkg/controller/collections" "github.com/ente-io/museum/pkg/controller/filedata" + "github.com/ente-io/museum/pkg/controller/public" "net/http" "strconv" @@ -20,7 +21,7 @@ import ( // PublicCollectionHandler exposes request handlers for publicly accessible collections type PublicCollectionHandler struct { - Controller *controller.PublicCollectionController + Controller *public.CollectionLinkController FileCtrl *controller.FileController CollectionCtrl *collections.CollectionController FileDataCtrl *filedata.Controller diff --git a/server/pkg/controller/collections/collection.go b/server/pkg/controller/collections/collection.go index 5f096bc133..5b86d38d0a 100644 --- a/server/pkg/controller/collections/collection.go +++ b/server/pkg/controller/collections/collection.go @@ -6,6 +6,7 @@ import ( "github.com/ente-io/museum/pkg/controller" "github.com/ente-io/museum/pkg/controller/access" "github.com/ente-io/museum/pkg/controller/email" + "github.com/ente-io/museum/pkg/controller/public" "github.com/ente-io/museum/pkg/repo/cast" "github.com/ente-io/museum/pkg/utils/array" "github.com/ente-io/museum/pkg/utils/auth" @@ -24,16 +25,16 @@ const ( // CollectionController encapsulates logic that deals with collections type CollectionController struct { - PublicCollectionCtrl *controller.PublicCollectionController - EmailCtrl *email.EmailNotificationController - AccessCtrl access.Controller - BillingCtrl *controller.BillingController - CollectionRepo *repo.CollectionRepository - UserRepo *repo.UserRepository - FileRepo *repo.FileRepository - QueueRepo *repo.QueueRepository - CastRepo *cast.Repository - TaskRepo *repo.TaskLockRepository + CollectionLinkCtrl *public.CollectionLinkController + EmailCtrl *email.EmailNotificationController + AccessCtrl access.Controller + BillingCtrl *controller.BillingController + CollectionRepo *repo.CollectionRepository + UserRepo *repo.UserRepository + FileRepo *repo.FileRepository + QueueRepo *repo.QueueRepository + CastRepo *cast.Repository + TaskRepo *repo.TaskLockRepository } // Create creates a collection @@ -148,7 +149,7 @@ func (c *CollectionController) TrashV3(ctx *gin.Context, req ente.TrashCollectio } } - err = c.PublicCollectionCtrl.Disable(ctx, cID) + err = c.CollectionLinkCtrl.Disable(ctx, cID) if err != nil { return stacktrace.Propagate(err, "failed to disabled public share url") } @@ -209,7 +210,7 @@ func (c *CollectionController) HandleAccountDeletion(ctx context.Context, userID if err != nil { return stacktrace.Propagate(err, "failed to revoke cast token for user") } - err = c.PublicCollectionCtrl.HandleAccountDeletion(ctx, userID, logger) + err = c.CollectionLinkCtrl.HandleAccountDeletion(ctx, userID, logger) return stacktrace.Propagate(err, "") } diff --git a/server/pkg/controller/collections/share.go b/server/pkg/controller/collections/share.go index ced64f0fdf..6002c7b493 100644 --- a/server/pkg/controller/collections/share.go +++ b/server/pkg/controller/collections/share.go @@ -70,21 +70,21 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec if !collection.AllowSharing() { return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("joining %s is not allowed", collection.Type)) } - publicCollectionToken, err := c.PublicCollectionCtrl.GetActivePublicCollectionToken(ctx, req.CollectionID) + collectionLinkToken, err := c.CollectionLinkCtrl.GetActiveCollectionLinkToken(ctx, req.CollectionID) if err != nil { return stacktrace.Propagate(err, "") } - if canJoin := publicCollectionToken.CanJoin(); canJoin != nil { + if canJoin := collectionLinkToken.CanJoin(); canJoin != nil { return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("can not join collection: %s", canJoin.Error())) } accessToken := auth.GetAccessToken(ctx) - if publicCollectionToken.Token != accessToken { + if collectionLinkToken.Token != accessToken { return stacktrace.Propagate(ente.ErrPermissionDenied, "token doesn't match collection") } - if publicCollectionToken.PassHash != nil && *publicCollectionToken.PassHash != "" { + if collectionLinkToken.PassHash != nil && *collectionLinkToken.PassHash != "" { accessTokenJWT := auth.GetAccessTokenJWT(ctx) - if passCheckErr := c.PublicCollectionCtrl.ValidateJWTToken(ctx, accessTokenJWT, *publicCollectionToken.PassHash); passCheckErr != nil { + if passCheckErr := c.CollectionLinkCtrl.ValidateJWTToken(ctx, accessTokenJWT, *collectionLinkToken.PassHash); passCheckErr != nil { return stacktrace.Propagate(passCheckErr, "") } } @@ -93,7 +93,7 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec return stacktrace.Propagate(err, "") } role := ente.VIEWER - if publicCollectionToken.EnableCollect { + if collectionLinkToken.EnableCollect { role = ente.COLLABORATOR } joinErr := c.CollectionRepo.Share(req.CollectionID, collection.Owner.ID, userID, req.EncryptedKey, role, time.Microseconds()) @@ -197,7 +197,7 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } - response, err := c.PublicCollectionCtrl.CreateAccessToken(ctx, req) + response, err := c.CollectionLinkCtrl.CreateLink(ctx, req) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } @@ -205,20 +205,26 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e } // UpdateShareURL updates the shared url configuration -func (c *CollectionController) UpdateShareURL(ctx context.Context, userID int64, req ente.UpdatePublicAccessTokenRequest) ( - ente.PublicURL, error) { +func (c *CollectionController) UpdateShareURL( + ctx context.Context, + userID int64, + req ente.UpdatePublicAccessTokenRequest, +) (*ente.PublicURL, error) { + if err := req.Validate(); err != nil { + return nil, stacktrace.Propagate(err, "") + } if err := c.verifyOwnership(req.CollectionID, userID); err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) if err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } - response, err := c.PublicCollectionCtrl.UpdateSharedUrl(ctx, req) + response, err := c.CollectionLinkCtrl.UpdateSharedUrl(ctx, req) if err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } - return response, nil + return &response, nil } // DisableSharedURL disable a public auth-token for the given collectionID @@ -226,7 +232,7 @@ func (c *CollectionController) DisableSharedURL(ctx context.Context, userID int6 if err := c.verifyOwnership(cID, userID); err != nil { return stacktrace.Propagate(err, "") } - err := c.PublicCollectionCtrl.Disable(ctx, cID) + err := c.CollectionLinkCtrl.Disable(ctx, cID) return stacktrace.Propagate(err, "") } diff --git a/server/pkg/controller/public_collection.go b/server/pkg/controller/public/collection_link.go similarity index 69% rename from server/pkg/controller/public_collection.go rename to server/pkg/controller/public/collection_link.go index 022a08812e..ead744bfbf 100644 --- a/server/pkg/controller/public_collection.go +++ b/server/pkg/controller/public/collection_link.go @@ -1,12 +1,13 @@ -package controller +package public import ( "context" "errors" "fmt" + "github.com/ente-io/museum/pkg/controller" + "github.com/ente-io/museum/pkg/repo/public" "github.com/ente-io/museum/ente" - enteJWT "github.com/ente-io/museum/ente/jwt" emailCtrl "github.com/ente-io/museum/pkg/controller/email" "github.com/ente-io/museum/pkg/repo" "github.com/ente-io/museum/pkg/utils/auth" @@ -14,7 +15,6 @@ import ( "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" "github.com/lithammer/shortuuid/v3" "github.com/sirupsen/logrus" ) @@ -49,23 +49,24 @@ const ( AbuseLimitExceededTemplate = "report_limit_exceeded_alert.html" ) -// PublicCollectionController controls share collection operations -type PublicCollectionController struct { - FileController *FileController +// CollectionLinkController controls share collection operations +type CollectionLinkController struct { + FileController *controller.FileController EmailNotificationCtrl *emailCtrl.EmailNotificationController - PublicCollectionRepo *repo.PublicCollectionRepository + CollectionLinkRepo *public.CollectionLinkRepo + FileLinkRepo *public.FileLinkRepository CollectionRepo *repo.CollectionRepository UserRepo *repo.UserRepository JwtSecret []byte } -func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) { +func (c *CollectionLinkController) CreateLink(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) { accessToken := shortuuid.New()[0:AccessTokenLength] - err := c.PublicCollectionRepo. + err := c.CollectionLinkRepo. Insert(ctx, req.CollectionID, accessToken, req.ValidTill, req.DeviceLimit, req.EnableCollect, req.EnableJoin) if err != nil { if errors.Is(err, ente.ErrActiveLinkAlreadyExists) { - collectionToPubUrlMap, err2 := c.PublicCollectionRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID}) + collectionToPubUrlMap, err2 := c.CollectionLinkRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID}) if err2 != nil { return ente.PublicURL{}, stacktrace.Propagate(err2, "") } @@ -81,7 +82,7 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req } } response := ente.PublicURL{ - URL: c.PublicCollectionRepo.GetAlbumUrl(accessToken), + URL: c.CollectionLinkRepo.GetAlbumUrl(accessToken), ValidTill: req.ValidTill, DeviceLimit: req.DeviceLimit, EnableDownload: true, @@ -91,11 +92,11 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req return response, nil } -func (c *PublicCollectionController) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) { - return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID) +func (c *CollectionLinkController) GetActiveCollectionLinkToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) { + return c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, collectionID) } -func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) { +func (c *CollectionLinkController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) { collection, err := c.GetPublicCollection(ctx, true) if err != nil { return ente.File{}, stacktrace.Propagate(err, "") @@ -118,13 +119,13 @@ func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File } // Disable all public accessTokens generated for the given cID till date. -func (c *PublicCollectionController) Disable(ctx context.Context, cID int64) error { - err := c.PublicCollectionRepo.DisableSharing(ctx, cID) +func (c *CollectionLinkController) Disable(ctx context.Context, cID int64) error { + err := c.CollectionLinkRepo.DisableSharing(ctx, cID) return stacktrace.Propagate(err, "") } -func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) { - publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, req.CollectionID) +func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) { + publicCollectionToken, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, req.CollectionID) if err != nil { return ente.PublicURL{}, err } @@ -154,12 +155,12 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en if req.EnableJoin != nil { publicCollectionToken.EnableJoin = *req.EnableJoin } - err = c.PublicCollectionRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken) + err = c.CollectionLinkRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } return ente.PublicURL{ - URL: c.PublicCollectionRepo.GetAlbumUrl(publicCollectionToken.Token), + URL: c.CollectionLinkRepo.GetAlbumUrl(publicCollectionToken.Token), DeviceLimit: publicCollectionToken.DeviceLimit, ValidTill: publicCollectionToken.ValidTill, EnableDownload: publicCollectionToken.EnableDownload, @@ -176,58 +177,23 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en // used by the client to pass in other requests for public collection. // Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force // attack for guessing password. -func (c *PublicCollectionController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { +func (c *CollectionLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { accessContext := auth.MustGetPublicAccessContext(ctx) - publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, accessContext.CollectionID) + collectionLinkRow, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, accessContext.CollectionID) if err != nil { return nil, stacktrace.Propagate(err, "failed to get public collection info") } - if publicCollectionToken.PassHash == nil || *publicCollectionToken.PassHash == "" { - return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link") - } - if req.PassHash != *publicCollectionToken.PassHash { - return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link") - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.PublicAlbumPasswordClaim{ - PassHash: req.PassHash, - ExpiryTime: time.NDaysFromNow(365), - }) - // Sign and get the complete encoded token as a string using the secret - tokenString, err := token.SignedString(c.JwtSecret) - - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return &ente.VerifyPasswordResponse{ - JWTToken: tokenString, - }, nil + return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req) } -func (c *PublicCollectionController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error { - token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.PublicAlbumPasswordClaim{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil - } - return c.JwtSecret, nil - }) - if err != nil { - return stacktrace.Propagate(err, "JWT parsed failed") - } - claims, ok := token.Claims.(*enteJWT.PublicAlbumPasswordClaim) - - if !ok { - return stacktrace.Propagate(errors.New("no claim in jwt token"), "") - } - if token.Valid && claims.PassHash == passwordHash { - return nil - } - return ente.ErrInvalidPassword +func (c *CollectionLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error { + return validateJWTToken(c.JwtSecret, jwtToken, passwordHash) } // ReportAbuse captures abuse report for a publicly shared collection. // It will also disable the accessToken for the collection if total abuse reports for the said collection // reaches AutoDisableAbuseThreshold -func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error { +func (c *CollectionLinkController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error { accessContext := auth.MustGetPublicAccessContext(ctx) readableReason, found := AllowedReasons[req.Reason] if !found { @@ -235,11 +201,11 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus } logrus.WithField("collectionID", accessContext.CollectionID).Error("CRITICAL: received abuse report") - err := c.PublicCollectionRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details) + err := c.CollectionLinkRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details) if err != nil { return stacktrace.Propagate(err, "") } - count, err := c.PublicCollectionRepo.GetAbuseReportCount(ctx, accessContext) + count, err := c.CollectionLinkRepo.GetAbuseReportCount(ctx, accessContext) if err != nil { return stacktrace.Propagate(err, "") } @@ -253,7 +219,7 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus return nil } -func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) { +func (c *CollectionLinkController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) { collection, err := c.CollectionRepo.Get(collectionID) if err != nil { logrus.Error("Could not get collection for abuse report") @@ -292,9 +258,9 @@ func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, r } } -func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error { +func (c *CollectionLinkController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error { logger.Info("updating public collection on account deletion") - collectionIDs, err := c.PublicCollectionRepo.GetActivePublicTokenForUser(ctx, userID) + collectionIDs, err := c.CollectionLinkRepo.GetActivePublicTokenForUser(ctx, userID) if err != nil { return stacktrace.Propagate(err, "") } @@ -305,12 +271,12 @@ func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context, return stacktrace.Propagate(err, "") } } - return nil + return c.FileLinkRepo.DisableLinksForUser(ctx, userID) } // GetPublicCollection will return collection info for a public url. // is mustAllowCollect is set to true but the underlying collection doesn't allow uploading -func (c *PublicCollectionController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) { +func (c *CollectionLinkController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) { accessContext := auth.MustGetPublicAccessContext(ctx) collection, err := c.CollectionRepo.Get(accessContext.CollectionID) if err != nil { diff --git a/server/pkg/controller/public/file_link.go b/server/pkg/controller/public/file_link.go new file mode 100644 index 0000000000..be015bc719 --- /dev/null +++ b/server/pkg/controller/public/file_link.go @@ -0,0 +1,162 @@ +package public + +import ( + "github.com/ente-io/museum/ente" + "github.com/ente-io/museum/pkg/controller" + "github.com/ente-io/museum/pkg/repo" + "github.com/ente-io/museum/pkg/repo/public" + "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/stacktrace" + "github.com/gin-gonic/gin" + "github.com/lithammer/shortuuid/v3" +) + +// FileLinkController controls share collection operations +type FileLinkController struct { + FileController *controller.FileController + FileLinkRepo *public.FileLinkRepository + FileRepo *repo.FileRepository + JwtSecret []byte +} + +func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl) (*ente.FileUrl, error) { + actorUserID := auth.GetUserID(ctx.Request.Header) + app := auth.GetApp(ctx) + if req.App != app { + return nil, stacktrace.Propagate(ente.NewBadRequestWithMessage("app mismatch"), "app mismatch") + } + file, err := c.FileRepo.GetFileAttributes(req.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file attributes") + } + if actorUserID != file.OwnerID { + return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } + accessToken := shortuuid.New()[0:AccessTokenLength] + _, err = c.FileLinkRepo.Insert(ctx, req.FileID, actorUserID, accessToken, app) + if err == nil || err == ente.ErrActiveLinkAlreadyExists { + row, rowErr := c.FileLinkRepo.GetFileUrlRowByFileID(ctx, req.FileID) + if rowErr != nil { + return nil, stacktrace.Propagate(rowErr, "failed to get active file url token") + } + return c.mapRowToFileUrl(ctx, row), nil + } + return nil, stacktrace.Propagate(err, "failed to create public file link") +} + +// Disable all public accessTokens generated for the given fileID till date. +func (c *FileLinkController) Disable(ctx *gin.Context, fileID int64) error { + userID := auth.GetUserID(ctx.Request.Header) + file, err := c.FileRepo.GetFileAttributes(fileID) + if err != nil { + return stacktrace.Propagate(err, "failed to get file attributes") + } + if userID != file.OwnerID { + return stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } + return c.FileLinkRepo.DisableLinkForFiles(ctx, []int64{fileID}) +} + +func (c *FileLinkController) GetUrls(ctx *gin.Context, sinceTime int64, limit int64) ([]*ente.FileUrl, error) { + userID := auth.GetUserID(ctx.Request.Header) + app := auth.GetApp(ctx) + fileLinks, err := c.FileLinkRepo.GetFileUrls(ctx, userID, sinceTime, limit, app) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file urls") + } + var fileUrls []*ente.FileUrl + for _, row := range fileLinks { + fileUrls = append(fileUrls, c.mapRowToFileUrl(ctx, row)) + } + return fileUrls, nil +} + +func (c *FileLinkController) UpdateSharedUrl(ctx *gin.Context, req ente.UpdateFileUrl) (*ente.FileUrl, error) { + if err := req.Validate(); err != nil { + return nil, stacktrace.Propagate(err, "invalid request") + } + fileLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, req.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file link info") + } + if fileLinkRow.OwnerID != auth.GetUserID(ctx.Request.Header) { + return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } + if req.ValidTill != nil { + fileLinkRow.ValidTill = *req.ValidTill + } + if req.DeviceLimit != nil { + fileLinkRow.DeviceLimit = *req.DeviceLimit + } + if req.PassHash != nil && req.Nonce != nil && req.OpsLimit != nil && req.MemLimit != nil { + fileLinkRow.PassHash = req.PassHash + fileLinkRow.Nonce = req.Nonce + fileLinkRow.OpsLimit = req.OpsLimit + fileLinkRow.MemLimit = req.MemLimit + } else if req.DisablePassword != nil && *req.DisablePassword { + fileLinkRow.PassHash = nil + fileLinkRow.Nonce = nil + fileLinkRow.OpsLimit = nil + fileLinkRow.MemLimit = nil + } + if req.EnableDownload != nil { + fileLinkRow.EnableDownload = *req.EnableDownload + } + + err = c.FileLinkRepo.UpdateLink(ctx, *fileLinkRow) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return c.mapRowToFileUrl(ctx, fileLinkRow), nil +} + +func (c *FileLinkController) Info(ctx *gin.Context) (*ente.File, error) { + accessContext := auth.MustGetFileLinkAccessContext(ctx) + return c.FileRepo.GetFileAttributes(accessContext.FileID) +} + +func (c *FileLinkController) PassInfo(ctx *gin.Context) (*ente.FileLinkRow, error) { + accessContext := auth.MustGetFileLinkAccessContext(ctx) + return c.FileLinkRepo.GetFileUrlRowByFileID(ctx, accessContext.FileID) +} + +// VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be +// used by the client to pass in other requests for public collection. +// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force +// attack for guessing password. +func (c *FileLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { + accessContext := auth.MustGetFileLinkAccessContext(ctx) + collectionLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, accessContext.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get public collection info") + } + return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req) +} + +func (c *FileLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error { + return validateJWTToken(c.JwtSecret, jwtToken, passwordHash) +} + +func (c *FileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl { + app := auth.GetApp(ctx) + var url string + if app == ente.Locker { + url = c.FileLinkRepo.LockerFileLink(row.Token) + } else { + url = c.FileLinkRepo.PhotoLink(row.Token) + } + return &ente.FileUrl{ + LinkID: row.LinkID, + FileID: row.FileID, + URL: url, + OwnerID: row.OwnerID, + ValidTill: row.ValidTill, + DeviceLimit: row.DeviceLimit, + PasswordEnabled: row.PassHash != nil, + Nonce: row.Nonce, + OpsLimit: row.OpsLimit, + MemLimit: row.MemLimit, + EnableDownload: row.EnableDownload, + CreatedAt: row.CreatedAt, + } +} diff --git a/server/pkg/controller/public/link_common.go b/server/pkg/controller/public/link_common.go new file mode 100644 index 0000000000..9a56334a0e --- /dev/null +++ b/server/pkg/controller/public/link_common.go @@ -0,0 +1,54 @@ +package public + +import ( + "errors" + "fmt" + "github.com/ente-io/museum/ente" + enteJWT "github.com/ente-io/museum/ente/jwt" + "github.com/ente-io/museum/pkg/utils/time" + "github.com/ente-io/stacktrace" + "github.com/golang-jwt/jwt" +) + +func validateJWTToken(secret []byte, jwtToken string, passwordHash string) error { + token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil + } + return secret, nil + }) + if err != nil { + return stacktrace.Propagate(err, "JWT parsed failed") + } + claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim) + + if !ok { + return stacktrace.Propagate(errors.New("no claim in jwt token"), "") + } + if token.Valid && claims.PassHash == passwordHash { + return nil + } + return ente.ErrInvalidPassword +} + +func verifyPassword(secret []byte, expectedPassHash *string, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) { + if expectedPassHash == nil || *expectedPassHash == "" { + return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link") + } + if req.PassHash != *expectedPassHash { + return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link") + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{ + PassHash: req.PassHash, + ExpiryTime: time.NDaysFromNow(365), + }) + // Sign and get the complete encoded token as a string using the secret + tokenString, err := token.SignedString(secret) + + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &ente.VerifyPasswordResponse{ + JWTToken: tokenString, + }, nil +} diff --git a/server/pkg/middleware/access_token.go b/server/pkg/middleware/collection_link.go similarity index 81% rename from server/pkg/middleware/access_token.go rename to server/pkg/middleware/collection_link.go index 702af77db8..5b11a5ec07 100644 --- a/server/pkg/middleware/access_token.go +++ b/server/pkg/middleware/collection_link.go @@ -5,6 +5,8 @@ import ( "context" "crypto/sha256" "fmt" + public2 "github.com/ente-io/museum/pkg/controller/public" + "github.com/ente-io/museum/pkg/repo/public" "net/http" "github.com/ente-io/museum/ente" @@ -24,20 +26,20 @@ import ( var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"} var whitelistedCollectionShareIDs = []int64{111} -// AccessTokenMiddleware intercepts and authenticates incoming requests -type AccessTokenMiddleware struct { - PublicCollectionRepo *repo.PublicCollectionRepository - PublicCollectionCtrl *controller.PublicCollectionController +// CollectionLinkMiddleware intercepts and authenticates incoming requests +type CollectionLinkMiddleware struct { + CollectionLinkRepo *public.CollectionLinkRepo + PublicCollectionCtrl *public2.CollectionLinkController CollectionRepo *repo.CollectionRepository Cache *cache.Cache BillingCtrl *controller.BillingController DiscordController *discord.DiscordController } -// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token` +// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` // within the header of a request and uses it to validate the access token and set the // ente.PublicAccessContext with auth.PublicAccessKey as key -func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { +func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { return func(c *gin.Context) { accessToken := auth.GetAccessToken(c) if accessToken == "" { @@ -52,7 +54,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":") cachedValue, cacheHit := m.Cache.Get(cacheKey) if !cacheHit { - publicCollectionSummary, err = m.PublicCollectionRepo.GetCollectionSummaryByToken(c, accessToken) + publicCollectionSummary, err = m.CollectionLinkRepo.GetCollectionSummaryByToken(c, accessToken) if err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) return @@ -112,7 +114,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g c.Next() } } -func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error { +func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error { userID, err := m.CollectionRepo.GetOwnerID(cID) if err != nil { return stacktrace.Propagate(err, "") @@ -120,7 +122,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error { return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false) } -func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, +func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context, collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) { // skip deviceLimit check & record keeping for requests via CF worker if network.IsCFWorkerIP(ip) { @@ -128,7 +130,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, } sharedID := collectionSummary.ID - hasAccessedInPast, err := m.PublicCollectionRepo.AccessedInPast(ctx, sharedID, ip, ua) + hasAccessedInPast, err := m.CollectionLinkRepo.AccessedInPast(ctx, sharedID, ip, ua) if err != nil { return false, stacktrace.Propagate(err, "") } @@ -136,17 +138,17 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, if hasAccessedInPast { return false, nil } - count, err := m.PublicCollectionRepo.GetUniqueAccessCount(ctx, sharedID) + count, err := m.CollectionLinkRepo.GetUniqueAccessCount(ctx, sharedID) if err != nil { return false, stacktrace.Propagate(err, "failed to get unique access count") } deviceLimit := int64(collectionSummary.DeviceLimit) - if deviceLimit == controller.DeviceLimitThreshold { - deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold + if deviceLimit == public2.DeviceLimitThreshold { + deviceLimit = public2.DeviceLimitThresholdMultiplier * public2.DeviceLimitThreshold } - if count >= controller.DeviceLimitWarningThreshold { + if count >= public2.DeviceLimitWarningThreshold { if !array.Int64InList(sharedID, whitelistedCollectionShareIDs) { m.DiscordController.NotifyPotentialAbuse( fmt.Sprintf("Album exceeds warning threshold: {CollectionID: %d, ShareID: %d}", @@ -157,12 +159,12 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context, if deviceLimit > 0 && count >= deviceLimit { return true, nil } - err = m.PublicCollectionRepo.RecordAccessHistory(ctx, sharedID, ip, ua) + err = m.CollectionLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua) return false, stacktrace.Propagate(err, "failed to record access history") } // validatePassword will verify if the user is provided correct password for the public album -func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string, +func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath string, collectionSummary ente.PublicCollectionSummary) error { if array.StringInList(reqPath, passwordWhiteListedURLs) { return nil diff --git a/server/pkg/middleware/file_link.go b/server/pkg/middleware/file_link.go new file mode 100644 index 0000000000..72b095bbc0 --- /dev/null +++ b/server/pkg/middleware/file_link.go @@ -0,0 +1,168 @@ +package middleware + +import ( + "context" + "fmt" + publicCtrl "github.com/ente-io/museum/pkg/controller/public" + "github.com/ente-io/museum/pkg/repo/public" + "github.com/ente-io/museum/pkg/utils/array" + "net/http" + + "github.com/ente-io/museum/ente" + "github.com/ente-io/museum/pkg/controller" + "github.com/ente-io/museum/pkg/controller/discord" + "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/network" + "github.com/ente-io/museum/pkg/utils/time" + "github.com/ente-io/stacktrace" + "github.com/gin-gonic/gin" + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" +) + +var filePasswordWhiteListedURLs = []string{"/file-link/pass-info", "/file-link/verify-password"} + +// FileLinkMiddleware intercepts and authenticates incoming requests +type FileLinkMiddleware struct { + FileLinkRepo *public.FileLinkRepository + FileLinkCtrl *publicCtrl.FileLinkController + Cache *cache.Cache + BillingCtrl *controller.BillingController + DiscordController *discord.DiscordController +} + +// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` +// within the header of a request and uses it to validate the access token and set the +// ente.PublicAccessContext with auth.PublicAccessKey as key +func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc { + return func(c *gin.Context) { + accessToken := auth.GetAccessToken(c) + if accessToken == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"}) + return + } + clientIP := network.GetClientIP(c) + userAgent := c.GetHeader("User-Agent") + + cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":") + cachedValue, cacheHit := m.Cache.Get(cacheKey) + var fileLinkRow *ente.FileLinkRow + var err error + if !cacheHit { + fileLinkRow, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken) + if err != nil { + logrus.WithError(err).Info("failed to get file link row by token") + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + return + } + if fileLinkRow.IsDisabled { + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"}) + return + } + // validate if user still has active paid subscription + if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(fileLinkRow.OwnerID, true); err != nil { + logrus.WithError(err).Info("failed to verify active paid subscription") + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"}) + return + } + + // validate device limit + reached, limitErr := m.isDeviceLimitReached(c, fileLinkRow, clientIP, userAgent) + if limitErr != nil { + logrus.WithError(limitErr).Error("failed to check device limit") + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"}) + return + } + if reached { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"}) + return + } + } else { + fileLinkRow = cachedValue.(*ente.FileLinkRow) + } + + if fileLinkRow.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry + fileLinkRow.ValidTill < time.Microseconds() { + c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"}) + return + } + + // checks password protected public collection + if fileLinkRow.PassHash != nil && *fileLinkRow.PassHash != "" { + reqPath := urlSanitizer(c) + if err = m.validatePassword(c, reqPath, fileLinkRow); err != nil { + logrus.WithError(err).Warn("password validation failed") + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err}) + return + } + } + + if !cacheHit { + m.Cache.Set(cacheKey, fileLinkRow, cache.DefaultExpiration) + } + + c.Set(auth.FileLinkAccessKey, &ente.FileLinkAccessContext{ + LinkID: fileLinkRow.LinkID, + IP: clientIP, + UserAgent: userAgent, + FileID: fileLinkRow.FileID, + OwnerID: fileLinkRow.OwnerID, + }) + c.Next() + } +} + +func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context, + collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) { + // skip deviceLimit check & record keeping for requests via CF worker + if network.IsCFWorkerIP(ip) { + return false, nil + } + + sharedID := collectionSummary.LinkID + hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua) + if err != nil { + return false, stacktrace.Propagate(err, "") + } + // if the device has accessed the url in the past, let it access it now as well, irrespective of device limit. + if hasAccessedInPast { + return false, nil + } + count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID) + if err != nil { + return false, stacktrace.Propagate(err, "failed to get unique access count") + } + + deviceLimit := int64(collectionSummary.DeviceLimit) + if deviceLimit == publicCtrl.DeviceLimitThreshold { + deviceLimit = publicCtrl.DeviceLimitThresholdMultiplier * publicCtrl.DeviceLimitThreshold + } + + if count >= publicCtrl.DeviceLimitWarningThreshold { + m.DiscordController.NotifyPotentialAbuse( + fmt.Sprintf("FileLink exceeds warning threshold: {FileID: %d, ShareID: %s}", + collectionSummary.FileID, collectionSummary.LinkID)) + } + + if deviceLimit > 0 && count >= deviceLimit { + return true, nil + } + err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua) + return false, stacktrace.Propagate(err, "failed to record access history") +} + +// validatePassword will verify if the user is provided correct password for the public album +func (m *FileLinkMiddleware) validatePassword( + c *gin.Context, + reqPath string, + fileLinkRow *ente.FileLinkRow, +) error { + accessTokenJWT := auth.GetAccessTokenJWT(c) + if accessTokenJWT == "" { + if array.StringInList(reqPath, filePasswordWhiteListedURLs) { + return nil + } + return &ente.ErrPassProtectedResource + } + return m.FileLinkCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash) +} diff --git a/server/pkg/middleware/rate_limit.go b/server/pkg/middleware/rate_limit.go index 14d3c92a00..bf4c403cfe 100644 --- a/server/pkg/middleware/rate_limit.go +++ b/server/pkg/middleware/rate_limit.go @@ -140,6 +140,7 @@ func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limi reqPath == "/users/verify-email" || reqPath == "/user/change-email" || reqPath == "/public-collection/verify-password" || + reqPath == "/file-link/verify-password" || reqPath == "/family/accept-invite" || reqPath == "/users/srp/attributes" || (reqPath == "/cast/device-info" && reqMethod == "POST") || diff --git a/server/pkg/repo/collection.go b/server/pkg/repo/collection.go index 3f9af70268..c76b40da50 100644 --- a/server/pkg/repo/collection.go +++ b/server/pkg/repo/collection.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/ente-io/museum/pkg/repo/public" "strconv" t "time" @@ -22,13 +23,13 @@ import ( // CollectionRepository defines the methods for inserting, updating and // retrieving collection entities from the underlying repository type CollectionRepository struct { - DB *sql.DB - FileRepo *FileRepository - PublicCollectionRepo *PublicCollectionRepository - TrashRepo *TrashRepository - SecretEncryptionKey []byte - QueueRepo *QueueRepository - LatencyLogger *prometheus.HistogramVec + DB *sql.DB + FileRepo *FileRepository + CollectionLinkRepo *public.CollectionLinkRepo + TrashRepo *TrashRepository + SecretEncryptionKey []byte + QueueRepo *QueueRepository + LatencyLogger *prometheus.HistogramVec } type SharedCollection struct { @@ -74,7 +75,7 @@ func (repo *CollectionRepository) Get(collectionID int64) (ente.Collection, erro c.EncryptedName = encryptedName.String c.NameDecryptionNonce = nameDecryptionNonce.String } - urlMap, err := repo.PublicCollectionRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID}) + urlMap, err := repo.CollectionLinkRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID}) if err != nil { return ente.Collection{}, stacktrace.Propagate(err, "failed to get publicURL info") } @@ -174,7 +175,7 @@ pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_ if _, ok := addPublicUrlMap[pctToken.String]; !ok { addPublicUrlMap[pctToken.String] = true url := ente.PublicURL{ - URL: repo.PublicCollectionRepo.GetAlbumUrl(pctToken.String), + URL: repo.CollectionLinkRepo.GetAlbumUrl(pctToken.String), DeviceLimit: int(pctDeviceLimit.Int32), ValidTill: pctValidTill.Int64, EnableDownload: pctEnableDownload.Bool, diff --git a/server/pkg/repo/file.go b/server/pkg/repo/file.go index 2ae4eafdca..50945cf6b1 100644 --- a/server/pkg/repo/file.go +++ b/server/pkg/repo/file.go @@ -638,6 +638,16 @@ func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.Fi return result, nil } +func (repo *FileRepository) GetFileAttributes(fileID int64) (*ente.File, error) { + rows := repo.DB.QueryRow(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = $1`, fileID) + var file ente.File + err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &file, nil +} + // GetUsage gets the Storage usage of a user // Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage func (repo *FileRepository) GetUsage(userID int64) (int64, error) { diff --git a/server/pkg/repo/public_collection.go b/server/pkg/repo/public/collection_link.go similarity index 78% rename from server/pkg/repo/public_collection.go rename to server/pkg/repo/public/collection_link.go index f5ae8f2d72..fafcd4cb11 100644 --- a/server/pkg/repo/public_collection.go +++ b/server/pkg/repo/public/collection_link.go @@ -1,4 +1,4 @@ -package repo +package public import ( "context" @@ -13,29 +13,29 @@ import ( const BaseShareURL = "https://albums.ente.io/?t=%s" -// PublicCollectionRepository defines the methods for inserting, updating and +// CollectionLinkRepo defines the methods for inserting, updating and // retrieving entities related to public collections -type PublicCollectionRepository struct { +type CollectionLinkRepo struct { DB *sql.DB albumHost string } -// NewPublicCollectionRepository .. -func NewPublicCollectionRepository(db *sql.DB, albumHost string) *PublicCollectionRepository { +// NewCollectionLinkRepository .. +func NewCollectionLinkRepository(db *sql.DB, albumHost string) *CollectionLinkRepo { if albumHost == "" { albumHost = "https://albums.ente.io" } - return &PublicCollectionRepository{ + return &CollectionLinkRepo{ DB: db, albumHost: albumHost, } } -func (pcr *PublicCollectionRepository) GetAlbumUrl(token string) string { +func (pcr *CollectionLinkRepo) GetAlbumUrl(token string) string { return fmt.Sprintf("%s/?t=%s", pcr.albumHost, token) } -func (pcr *PublicCollectionRepository) Insert(ctx context.Context, +func (pcr *CollectionLinkRepo) Insert(ctx context.Context, cID int64, token string, validTill int64, deviceLimit int, enableCollect bool, enableJoin *bool) error { // default value for enableJoin is true join := true @@ -51,7 +51,7 @@ func (pcr *PublicCollectionRepository) Insert(ctx context.Context, return stacktrace.Propagate(err, "failed to insert") } -func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID int64) error { +func (pcr *CollectionLinkRepo) DisableSharing(ctx context.Context, cID int64) error { _, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET is_disabled = true where collection_id = $1 and is_disabled = false`, cID) return stacktrace.Propagate(err, "failed to disable sharing") @@ -59,7 +59,7 @@ func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID i // GetCollectionToActivePublicURLMap will return map of collectionID to PublicURLs which are not disabled yet. // Note: The url could be expired or deviceLimit is already reached -func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) { +func (pcr *CollectionLinkRepo) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) { rows, err := pcr.DB.QueryContext(ctx, `SELECT collection_id, access_token, valid_till, device_limit, enable_download, enable_collect, enable_join, pw_nonce, mem_limit, ops_limit FROM public_collection_tokens WHERE collection_id = ANY($1) and is_disabled = FALSE`, pq.Array(collectionIDs)) @@ -92,26 +92,26 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con return result, nil } -// GetActivePublicCollectionToken will return ente.PublicCollectionToken for given collection ID +// GetActiveCollectionLinkRow will return ente.CollectionLinkRow for given collection ID // Note: The token could be expired or deviceLimit is already reached -func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) { +func (pcr *CollectionLinkRepo) GetActiveCollectionLinkRow(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit, is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`, collectionID) //defer rows.Close() - ret := ente.PublicCollectionToken{} + ret := ente.CollectionLinkRow{} err := row.Scan(&ret.ID, &ret.CollectionID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit, &ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload, &ret.EnableCollect, &ret.EnableJoin) if err != nil { - return ente.PublicCollectionToken{}, stacktrace.Propagate(err, "") + return ente.CollectionLinkRow{}, stacktrace.Propagate(err, "") } return ret, nil } // UpdatePublicCollectionToken will update the row for corresponding public collection token -func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.Context, pct ente.PublicCollectionToken) error { +func (pcr *CollectionLinkRepo) UpdatePublicCollectionToken(ctx context.Context, pct ente.CollectionLinkRow) error { _, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET valid_till = $1, device_limit = $2, pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7, enable_collect = $8, enable_join = $9 where id = $10`, @@ -119,7 +119,7 @@ func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.C return stacktrace.Propagate(err, "failed to update public collection token") } -func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext, +func (pcr *CollectionLinkRepo) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext, url string, reason string, details ente.AbuseReportDetails) error { _, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_abuse_report (share_id, ip, user_agent, url, reason, details) VALUES ($1, $2, $3, $4, $5, $6) @@ -128,7 +128,7 @@ func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, ac return stacktrace.Propagate(err, "failed to record abuse report") } -func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) { +func (pcr *CollectionLinkRepo) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_abuse_report WHERE share_id = $1`, accessCtx.ID) var count int64 = 0 err := row.Scan(&count) @@ -138,7 +138,7 @@ func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context, return count, nil } -func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) { +func (pcr *CollectionLinkRepo) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_collection_access_history WHERE share_id = $1`, shareId) var count int64 = 0 err := row.Scan(&count) @@ -148,7 +148,7 @@ func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context, return count, nil } -func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error { +func (pcr *CollectionLinkRepo) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error { _, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_collection_access_history (share_id, ip, user_agent) VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT unique_access_sid_ip_ua DO NOTHING;`, @@ -157,7 +157,7 @@ func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context, } // AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past -func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) { +func (pcr *CollectionLinkRepo) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) { row := pcr.DB.QueryRowContext(ctx, `select share_id from public_collection_access_history where share_id =$1 and ip = $2 and user_agent = $3`, shareID, ip, ua) var tempID int64 @@ -168,7 +168,7 @@ func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, share return true, stacktrace.Propagate(err, "failed to record access history") } -func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) { +func (pcr *CollectionLinkRepo) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT sct.id, sct.collection_id, sct.is_disabled, sct.valid_till, sct.device_limit, sct.pw_hash, sct.created_at, sct.updated_at, count(ah.share_id) @@ -185,7 +185,7 @@ func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.C return result, nil } -func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) { +func (pcr *CollectionLinkRepo) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) { rows, err := pcr.DB.QueryContext(ctx, `select pt.collection_id from public_collection_tokens pt left join collections c on pt.collection_id = c.collection_id where pt.is_disabled = FALSE and c.owner_id= $1;`, userID) if err != nil { return nil, stacktrace.Propagate(err, "") @@ -204,7 +204,7 @@ func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.C } // CleanupAccessHistory public_collection_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days -func (pcr *PublicCollectionRepository) CleanupAccessHistory(ctx context.Context) error { +func (pcr *CollectionLinkRepo) CleanupAccessHistory(ctx context.Context) error { _, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_collection_access_history WHERE share_id IN (SELECT id FROM public_collection_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`) if err != nil { return stacktrace.Propagate(err, "failed to clean up public collection access history") diff --git a/server/pkg/repo/public/file_link.go b/server/pkg/repo/public/file_link.go new file mode 100644 index 0000000000..fbb2f4e072 --- /dev/null +++ b/server/pkg/repo/public/file_link.go @@ -0,0 +1,218 @@ +package public + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/ente-io/museum/ente/base" + "github.com/lib/pq" + "github.com/spf13/viper" + + "github.com/ente-io/museum/ente" + "github.com/ente-io/stacktrace" +) + +// FileLinkRepository defines the methods for inserting, updating and +// retrieving entities related to public file +type FileLinkRepository struct { + DB *sql.DB + photoHost string + lockerHost string +} + +// NewFileLinkRepo .. +func NewFileLinkRepo(db *sql.DB) *FileLinkRepository { + albumHost := viper.GetString("apps.public-albums") + if albumHost == "" { + albumHost = "https://albums.ente.io" + } + lockerHost := viper.GetString("apps.public-locker") + if lockerHost == "" { + lockerHost = "https://locker.ente.io" + } + return &FileLinkRepository{ + DB: db, + photoHost: albumHost, + lockerHost: lockerHost, + } +} + +func (pcr *FileLinkRepository) PhotoLink(token string) string { + return fmt.Sprintf("%s/?t=%s", pcr.photoHost, token) +} + +func (pcr *FileLinkRepository) LockerFileLink(token string) string { + return fmt.Sprintf("%s/?t=%s", pcr.lockerHost, token) +} + +func (pcr *FileLinkRepository) Insert( + ctx context.Context, + fileID int64, + ownerID int64, + token string, + app ente.App, +) (*string, error) { + id, err := base.NewID("pft") + if err != nil { + return nil, stacktrace.Propagate(err, "failed to generate new ID for public file token") + } + _, err = pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens + (id, file_id, owner_id, access_token, app) VALUES ($1, $2, $3, $4, $5)`, + id, fileID, ownerID, token, string(app)) + if err != nil { + if err.Error() == "pq: duplicate key value violates unique constraint \"public_active_file_link_unique_idx\"" { + return nil, ente.ErrActiveLinkAlreadyExists + } + return nil, stacktrace.Propagate(err, "failed to insert") + } + return id, nil +} + +// GetActiveFileUrlToken will return ente.CollectionLinkRow for given collection ID +// Note: The token could be expired or deviceLimit is already reached +func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) { + row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit, + is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM + public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`, + fileID) + + ret := ente.FileLinkRow{} + err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit, + &ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &ret, nil +} +func (pcr *FileLinkRepository) GetFileUrls(ctx context.Context, userID int64, sinceTime int64, limit int64, app ente.App) ([]*ente.FileLinkRow, error) { + if limit <= 0 { + limit = 500 + } + query := `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit, + created_at, updated_at FROM public_file_tokens + WHERE owner_id = $1 AND created_at > $2 AND app = $3 ORDER BY updated_at DESC LIMIT $4` + rows, err := pcr.DB.QueryContext(ctx, query, userID, sinceTime, string(app), limit) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get public file urls") + } + defer rows.Close() + + var result []*ente.FileLinkRow + for rows.Next() { + var row ente.FileLinkRow + err = rows.Scan(&row.LinkID, &row.FileID, &row.OwnerID, &row.IsDisabled, + &row.ValidTill, &row.DeviceLimit, &row.EnableDownload, + &row.PassHash, &row.Nonce, &row.MemLimit, + &row.OpsLimit, &row.CreatedAt, &row.UpdatedAt) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to scan public file url row") + } + result = append(result, &row) + } + return result, nil +} + +func (pcr *FileLinkRepository) DisableLinkForFiles(ctx context.Context, fileIDs []int64) error { + if len(fileIDs) == 0 { + return nil + } + query := `UPDATE public_file_tokens SET is_disabled = TRUE WHERE file_id = ANY($1)` + _, err := pcr.DB.ExecContext(ctx, query, pq.Array(fileIDs)) + if err != nil { + return stacktrace.Propagate(err, "failed to disable public file links") + } + return nil +} + +// DisableLinksForUser will disable all public file links for the given user +func (pcr *FileLinkRepository) DisableLinksForUser(ctx context.Context, userID int64) error { + _, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET is_disabled = TRUE WHERE owner_id = $1`, userID) + if err != nil { + return stacktrace.Propagate(err, "failed to disable public file link") + } + return nil +} + +func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) { + row := pcr.DB.QueryRowContext(ctx, + `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit + created_at, updated_at + from public_file_tokens + where access_token = $1 +`, accessToken) + var result = ente.FileLinkRow{} + err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, ente.ErrNotFound + } + return nil, stacktrace.Propagate(err, "failed to get public file url summary by token") + } + return &result, nil +} + +func (pcr *FileLinkRepository) GetFileUrlRowByFileID(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) { + row := pcr.DB.QueryRowContext(ctx, + `SELECT id, file_id, access_token, owner_id, is_disabled, enable_download, valid_till, device_limit, pw_hash, pw_nonce, mem_limit, ops_limit, + created_at, updated_at + from public_file_tokens + where file_id = $1 and is_disabled = FALSE`, fileID) + var result = ente.FileLinkRow{} + err := row.Scan(&result.LinkID, &result.FileID, &result.Token, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, ente.ErrNotFound + } + return nil, stacktrace.Propagate(err, "failed to get public file url summary by file ID") + } + return &result, nil +} + +// UpdateLink will update the row for corresponding public file token +func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error { + _, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2, + pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7 + where id = $8`, + pct.ValidTill, pct.DeviceLimit, pct.PassHash, pct.Nonce, pct.MemLimit, pct.OpsLimit, pct.EnableDownload, pct.LinkID) + return stacktrace.Propagate(err, "failed to update public file token") +} + +func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId string) (int64, error) { + row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_file_tokens_access_history WHERE id = $1`, linkId) + var count int64 = 0 + err := row.Scan(&count) + if err != nil { + return -1, stacktrace.Propagate(err, "") + } + return count, nil +} + +func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error { + _, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history + (id, ip, user_agent) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`, + shareID, ip, ua) + return stacktrace.Propagate(err, "failed to record access history") +} + +// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past +func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) { + row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`, + shareID, ip, ua) + var tempID int64 + err := row.Scan(&tempID) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return true, stacktrace.Propagate(err, "failed to record access history") +} + +// CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days +func (pcr *FileLinkRepository) CleanupAccessHistory(ctx context.Context) error { + _, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_file_tokens_access_history WHERE id IN (SELECT id FROM public_file_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`) + if err != nil { + return stacktrace.Propagate(err, "failed to clean up public file access history") + } + return nil +} diff --git a/server/pkg/repo/trash.go b/server/pkg/repo/trash.go index f1ab3c2289..3781e0716e 100644 --- a/server/pkg/repo/trash.go +++ b/server/pkg/repo/trash.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/ente-io/museum/pkg/repo/public" "strings" "github.com/ente-io/museum/ente" @@ -32,10 +33,11 @@ type FileWithUpdatedAt struct { } type TrashRepository struct { - DB *sql.DB - ObjectRepo *ObjectRepository - FileRepo *FileRepository - QueueRepo *QueueRepository + DB *sql.DB + ObjectRepo *ObjectRepository + FileRepo *FileRepository + QueueRepo *QueueRepository + FileLinkRepo *public.FileLinkRepository } func (t *TrashRepository) InsertItems(ctx context.Context, tx *sql.Tx, userID int64, items []ente.TrashItemRequest) error { @@ -156,6 +158,13 @@ func (t *TrashRepository) TrashFiles(fileIDs []int64, userID int64, trash ente.T return stacktrace.Propagate(err, "") } err = tx.Commit() + + if err == nil { + removeLinkErr := t.FileLinkRepo.DisableLinkForFiles(ctx, fileIDs) + if removeLinkErr != nil { + return stacktrace.Propagate(removeLinkErr, "failed to disable file links for files being trashed") + } + } return stacktrace.Propagate(err, "") } diff --git a/server/pkg/utils/auth/auth.go b/server/pkg/utils/auth/auth.go index 6f8091998b..8b52808e36 100644 --- a/server/pkg/utils/auth/auth.go +++ b/server/pkg/utils/auth/auth.go @@ -17,8 +17,9 @@ import ( ) const ( - PublicAccessKey = "X-Public-Access-ID" - CastContext = "X-Cast-Context" + PublicAccessKey = "X-Public-Access-ID" + FileLinkAccessKey = "X-Public-FileLink-Access-ID" + CastContext = "X-Cast-Context" ) // GenerateRandomBytes returns securely generated random bytes. @@ -120,6 +121,8 @@ func GetCastToken(c *gin.Context) string { return token } +// GetAccessTokenJWT fetches the JWT access token from the request header or query parameters. +// This token is issued by server on password verification of links that are protected by password. func GetAccessTokenJWT(c *gin.Context) string { token := c.GetHeader("X-Auth-Access-Token-JWT") if token == "" { @@ -132,6 +135,10 @@ func MustGetPublicAccessContext(c *gin.Context) ente.PublicAccessContext { return c.MustGet(PublicAccessKey).(ente.PublicAccessContext) } +func MustGetFileLinkAccessContext(c *gin.Context) *ente.FileLinkAccessContext { + return c.MustGet(FileLinkAccessKey).(*ente.FileLinkAccessContext) +} + func GetCastCtx(c *gin.Context) cast.AuthContext { return c.MustGet(CastContext).(cast.AuthContext) }