diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index ed731b51da..21519ca285 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -311,16 +311,16 @@ func main() { } collectionController := &collections.CollectionController{ - CollectionRepo: collectionRepo, - EmailCtrl: emailNotificationCtrl, - AccessCtrl: accessCtrl, - CollectionLinkController: collectionLinkCtrl, - UserRepo: userRepo, - FileRepo: fileRepo, - CastRepo: &castDb, - BillingCtrl: billingController, - QueueRepo: queueRepo, - TaskRepo: taskLockingRepo, + CollectionRepo: collectionRepo, + EmailCtrl: emailNotificationCtrl, + AccessCtrl: accessCtrl, + CollectionLinkCtrl: collectionLinkCtrl, + UserRepo: userRepo, + FileRepo: fileRepo, + CastRepo: &castDb, + BillingCtrl: billingController, + QueueRepo: queueRepo, + TaskRepo: taskLockingRepo, } kexCtrl := &kexCtrl.Controller{ @@ -442,6 +442,9 @@ func main() { privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail) privateAPI.POST("files/share-url", fileHandler.ShareUrl) + privateAPI.PUT("files/share-url", fileHandler.UpdateFileURL) + privateAPI.DELETE("files/share-url/:fileID", fileHandler.DisableUrl) + privateAPI.GET("files/share-urls/", fileHandler.DisableUrl) privateAPI.PUT("/files/data", fileHandler.PutFileData) privateAPI.PUT("/files/video-data", fileHandler.PutVideoData) diff --git a/server/ente/public_file.go b/server/ente/file_link.go similarity index 57% rename from server/ente/public_file.go rename to server/ente/file_link.go index 12398c38a1..a69034bf11 100644 --- a/server/ente/public_file.go +++ b/server/ente/file_link.go @@ -1,8 +1,14 @@ package ente +import ( + "fmt" + "github.com/ente-io/museum/pkg/utils/time" +) + // CreateFileUrl represents an encrypted file in the system type CreateFileUrl struct { FileID int64 `json:"fileID" binding:"required"` + App App `json:"app" binding:"required"` } // UpdateFileUrl .. @@ -19,6 +25,33 @@ type UpdateFileUrl struct { DisablePassword *bool `json:"disablePassword"` } +func (ut *UpdateFileUrl) Validate() error { + if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil && + ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil { + return NewBadRequestWithMessage("all parameters are missing") + } + + if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) { + return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit)) + } + + if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() { + return NewBadRequestWithMessage("valid till should be greater than current timestamp") + } + + var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil + var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil + + if !(allPassParamsMissing || allPassParamsPresent) { + return NewBadRequestWithMessage("all password params should be either present or missing") + } + + if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword { + return NewBadRequestWithMessage("can not set and disable password in same request") + } + return nil +} + type FileLinkRow struct { LinkID string OwnerID int64 @@ -35,6 +68,7 @@ type FileLinkRow struct { CreatedAt int64 UpdatedAt int64 } + type FileUrl struct { LinkID string `json:"linkID" binding:"required"` URL string `json:"url" binding:"required"` diff --git a/server/ente/public_collection.go b/server/ente/public_collection.go index f34c0bf2f1..5f3867e6d0 100644 --- a/server/ente/public_collection.go +++ b/server/ente/public_collection.go @@ -3,7 +3,7 @@ package ente import ( "database/sql/driver" "encoding/json" - + "fmt" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" ) @@ -32,6 +32,33 @@ type UpdatePublicAccessTokenRequest struct { EnableJoin *bool `json:"enableJoin"` } +func (ut *UpdatePublicAccessTokenRequest) Validate() error { + if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil && + ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil && ut.EnableCollect == nil { + return NewBadRequestWithMessage("all parameters are missing") + } + + if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) { + return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit)) + } + + if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() { + return NewBadRequestWithMessage("valid till should be greater than current timestamp") + } + + var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil + var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil + + if !(allPassParamsMissing || allPassParamsPresent) { + return NewBadRequestWithMessage("all password params should be either present or missing") + } + + if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword { + return NewBadRequestWithMessage("can not set and disable password in same request") + } + return nil +} + type VerifyPasswordRequest struct { PassHash string `json:"passHash" binding:"required"` } diff --git a/server/migrations/102_single_file_url.up.sql b/server/migrations/102_single_file_url.up.sql index 5a1f55089a..0f94897ba2 100644 --- a/server/migrations/102_single_file_url.up.sql +++ b/server/migrations/102_single_file_url.up.sql @@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS public_file_tokens id text primary key, file_id bigint NOT NULL, owner_id bigint NOT NULL, + app text NOT NULL, access_token text not null, valid_till bigint not null DEFAULT 0, device_limit int not null DEFAULT 0, @@ -40,6 +41,6 @@ CREATE TABLE IF NOT EXISTS public_file_tokens_access_history ON DELETE CASCADE ); -CREATE UNIQUE INDEX IF NOT EXISTS public_access_token_unique_idx ON public_file_tokens (access_token) WHERE is_disabled = FALSE; +CREATE UNIQUE INDEX IF NOT EXISTS public_file_token_unique_idx ON public_file_tokens (access_token) WHERE is_disabled = FALSE; CREATE INDEX IF NOT EXISTS public_file_tokens_owner_id_updated_at_idx ON public_file_tokens (owner_id, updated_at); diff --git a/server/pkg/api/collection.go b/server/pkg/api/collection.go index d5f918c672..6c198652c7 100644 --- a/server/pkg/api/collection.go +++ b/server/pkg/api/collection.go @@ -3,7 +3,6 @@ package api import ( "fmt" "github.com/ente-io/museum/pkg/controller/collections" - "github.com/ente-io/museum/pkg/controller/public" "net/http" "strconv" @@ -172,35 +171,6 @@ func (h *CollectionHandler) UpdateShareURL(c *gin.Context) { handler.Error(c, stacktrace.Propagate(err, "")) return } - if req.DeviceLimit == nil && req.ValidTill == nil && req.DisablePassword == nil && - req.Nonce == nil && req.PassHash == nil && req.EnableDownload == nil && req.EnableCollect == nil { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all parameters are missing")) - return - } - - if req.DeviceLimit != nil && (*req.DeviceLimit < 0 || *req.DeviceLimit > public.DeviceLimitThreshold) { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("device limit: %d out of range", *req.DeviceLimit))) - return - } - - if req.ValidTill != nil && *req.ValidTill != 0 && *req.ValidTill < time.Microseconds() { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "valid till should be greater than current timestamp")) - return - } - - var allPassParamsMissing = req.Nonce == nil && req.PassHash == nil && req.MemLimit == nil && req.OpsLimit == nil - var allPassParamsPresent = req.Nonce != nil && req.PassHash != nil && req.MemLimit != nil && req.OpsLimit != nil - - if !(allPassParamsMissing || allPassParamsPresent) { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all password params should be either present or missing")) - return - } - - if allPassParamsPresent && req.DisablePassword != nil && *req.DisablePassword { - handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "can not set and disable password in same request")) - return - } - response, err := h.Controller.UpdateShareURL(c, auth.GetUserID(c.Request.Header), req) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) diff --git a/server/pkg/api/file_url.go b/server/pkg/api/file_url.go index 9073a0a331..1780345bfc 100644 --- a/server/pkg/api/file_url.go +++ b/server/pkg/api/file_url.go @@ -6,6 +6,7 @@ import ( "github.com/ente-io/stacktrace" "github.com/gin-gonic/gin" "net/http" + "strconv" ) // ShareUrl a sharable url for the file @@ -23,3 +24,58 @@ func (h *FileHandler) ShareUrl(c *gin.Context) { } c.JSON(http.StatusOK, response) } + +func (h *FileHandler) DisableUrl(c *gin.Context) { + cID, err := strconv.ParseInt(c.Param("fileID"), 10, 64) + if err != nil { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "")) + return + } + err = h.FileUrlCtrl.Disable(c, cID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} + +func (h *FileHandler) GetUrls(c *gin.Context) { + sinceTime, err := strconv.ParseInt(c.Query("sinceTime"), 10, 64) + if err != nil { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "")) + return + } + limit := 500 + if c.Query("limit") != "" { + limit, err = strconv.Atoi(c.Query("limit")) + if err != nil || limit < 1 { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "")) + return + } + } + response, err := h.FileUrlCtrl.GetUrls(c, sinceTime, int64(limit)) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "diff": response, + }) +} + +// UpdateFileURL updates the share URL for a file +func (h *FileHandler) UpdateFileURL(c *gin.Context) { + var req ente.UpdateFileUrl + if err := c.ShouldBindJSON(&req); err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + response, err := h.FileUrlCtrl.UpdateSharedUrl(c, req) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{ + "result": response, + }) +} diff --git a/server/pkg/controller/collections/collection.go b/server/pkg/controller/collections/collection.go index 52a36df782..5b86d38d0a 100644 --- a/server/pkg/controller/collections/collection.go +++ b/server/pkg/controller/collections/collection.go @@ -25,16 +25,16 @@ const ( // CollectionController encapsulates logic that deals with collections type CollectionController struct { - CollectionLinkController *public.CollectionLinkController - EmailCtrl *email.EmailNotificationController - AccessCtrl access.Controller - BillingCtrl *controller.BillingController - CollectionRepo *repo.CollectionRepository - UserRepo *repo.UserRepository - FileRepo *repo.FileRepository - QueueRepo *repo.QueueRepository - CastRepo *cast.Repository - TaskRepo *repo.TaskLockRepository + CollectionLinkCtrl *public.CollectionLinkController + EmailCtrl *email.EmailNotificationController + AccessCtrl access.Controller + BillingCtrl *controller.BillingController + CollectionRepo *repo.CollectionRepository + UserRepo *repo.UserRepository + FileRepo *repo.FileRepository + QueueRepo *repo.QueueRepository + CastRepo *cast.Repository + TaskRepo *repo.TaskLockRepository } // Create creates a collection @@ -149,7 +149,7 @@ func (c *CollectionController) TrashV3(ctx *gin.Context, req ente.TrashCollectio } } - err = c.CollectionLinkController.Disable(ctx, cID) + err = c.CollectionLinkCtrl.Disable(ctx, cID) if err != nil { return stacktrace.Propagate(err, "failed to disabled public share url") } @@ -210,7 +210,7 @@ func (c *CollectionController) HandleAccountDeletion(ctx context.Context, userID if err != nil { return stacktrace.Propagate(err, "failed to revoke cast token for user") } - err = c.CollectionLinkController.HandleAccountDeletion(ctx, userID, logger) + err = c.CollectionLinkCtrl.HandleAccountDeletion(ctx, userID, logger) return stacktrace.Propagate(err, "") } diff --git a/server/pkg/controller/collections/share.go b/server/pkg/controller/collections/share.go index 7651266ece..6002c7b493 100644 --- a/server/pkg/controller/collections/share.go +++ b/server/pkg/controller/collections/share.go @@ -70,7 +70,7 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec if !collection.AllowSharing() { return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("joining %s is not allowed", collection.Type)) } - collectionLinkToken, err := c.CollectionLinkController.GetActiveCollectionLinkToken(ctx, req.CollectionID) + collectionLinkToken, err := c.CollectionLinkCtrl.GetActiveCollectionLinkToken(ctx, req.CollectionID) if err != nil { return stacktrace.Propagate(err, "") } @@ -84,7 +84,7 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec } if collectionLinkToken.PassHash != nil && *collectionLinkToken.PassHash != "" { accessTokenJWT := auth.GetAccessTokenJWT(ctx) - if passCheckErr := c.CollectionLinkController.ValidateJWTToken(ctx, accessTokenJWT, *collectionLinkToken.PassHash); passCheckErr != nil { + if passCheckErr := c.CollectionLinkCtrl.ValidateJWTToken(ctx, accessTokenJWT, *collectionLinkToken.PassHash); passCheckErr != nil { return stacktrace.Propagate(passCheckErr, "") } } @@ -197,7 +197,7 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } - response, err := c.CollectionLinkController.CreateLink(ctx, req) + response, err := c.CollectionLinkCtrl.CreateLink(ctx, req) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") } @@ -205,20 +205,26 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e } // UpdateShareURL updates the shared url configuration -func (c *CollectionController) UpdateShareURL(ctx context.Context, userID int64, req ente.UpdatePublicAccessTokenRequest) ( - ente.PublicURL, error) { +func (c *CollectionController) UpdateShareURL( + ctx context.Context, + userID int64, + req ente.UpdatePublicAccessTokenRequest, +) (*ente.PublicURL, error) { + if err := req.Validate(); err != nil { + return nil, stacktrace.Propagate(err, "") + } if err := c.verifyOwnership(req.CollectionID, userID); err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) if err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } - response, err := c.CollectionLinkController.UpdateSharedUrl(ctx, req) + response, err := c.CollectionLinkCtrl.UpdateSharedUrl(ctx, req) if err != nil { - return ente.PublicURL{}, stacktrace.Propagate(err, "") + return nil, stacktrace.Propagate(err, "") } - return response, nil + return &response, nil } // DisableSharedURL disable a public auth-token for the given collectionID @@ -226,7 +232,7 @@ func (c *CollectionController) DisableSharedURL(ctx context.Context, userID int6 if err := c.verifyOwnership(cID, userID); err != nil { return stacktrace.Propagate(err, "") } - err := c.CollectionLinkController.Disable(ctx, cID) + err := c.CollectionLinkCtrl.Disable(ctx, cID) return stacktrace.Propagate(err, "") } diff --git a/server/pkg/controller/public/file_link.go b/server/pkg/controller/public/file_link.go index edb90ec155..39f1ba9fe7 100644 --- a/server/pkg/controller/public/file_link.go +++ b/server/pkg/controller/public/file_link.go @@ -15,17 +15,27 @@ import ( type FileLinkController struct { FileController *controller.FileController FileLinkRepo *public.FileLinkRepository - CollectionRepo *repo.CollectionRepository - UserRepo *repo.UserRepository + FileRepo *repo.FileRepository JwtSecret []byte } func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl) (*ente.FileUrl, error) { actorUserID := auth.GetUserID(ctx.Request.Header) + app := auth.GetApp(ctx) + if req.App != app { + return nil, stacktrace.Propagate(ente.NewBadRequestWithMessage("app mismatch"), "app mismatch") + } + file, err := c.FileRepo.GetFileAttributes(req.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file attributes") + } + if actorUserID != file.OwnerID { + return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } accessToken := shortuuid.New()[0:AccessTokenLength] - _, err := c.FileLinkRepo.Insert(ctx, req.FileID, actorUserID, accessToken) - if err == nil { - row, rowErr := c.FileLinkRepo.GetActiveFileUrlToken(ctx, req.FileID) + _, err = c.FileLinkRepo.Insert(ctx, req.FileID, actorUserID, accessToken, app) + if err == nil || err == ente.ErrActiveLinkAlreadyExists { + row, rowErr := c.FileLinkRepo.GetFileUrlRowByFileID(ctx, req.FileID) if rowErr != nil { return nil, stacktrace.Propagate(rowErr, "failed to get active file url token") } @@ -34,6 +44,72 @@ func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl return nil, stacktrace.Propagate(err, "failed to create public file link") } +// Disable all public accessTokens generated for the given fileID till date. +func (c *FileLinkController) Disable(ctx *gin.Context, fileID int64) error { + userID := auth.GetUserID(ctx.Request.Header) + file, err := c.FileRepo.GetFileAttributes(fileID) + if err != nil { + return stacktrace.Propagate(err, "failed to get file attributes") + } + if userID != file.OwnerID { + return stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } + return c.FileLinkRepo.DisableLinkForFiles(ctx, []int64{fileID}) +} + +func (c *FileLinkController) GetUrls(ctx *gin.Context, sinceTime int64, limit int64) ([]*ente.FileUrl, error) { + userID := auth.GetUserID(ctx.Request.Header) + app := auth.GetApp(ctx) + fileLinks, err := c.FileLinkRepo.GetFileUrls(ctx, userID, sinceTime, limit, app) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file urls") + } + var fileUrls []*ente.FileUrl + for _, row := range fileLinks { + fileUrls = append(fileUrls, c.mapRowToFileUrl(ctx, row)) + } + return fileUrls, nil +} + +func (c *FileLinkController) UpdateSharedUrl(ctx *gin.Context, req ente.UpdateFileUrl) (*ente.FileUrl, error) { + if err := req.Validate(); err != nil { + return nil, stacktrace.Propagate(err, "invalid request") + } + fileLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, req.FileID) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get file link info") + } + if fileLinkRow.OwnerID != auth.GetUserID(ctx.Request.Header) { + return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "") + } + if req.ValidTill != nil { + fileLinkRow.ValidTill = *req.ValidTill + } + if req.DeviceLimit != nil { + fileLinkRow.DeviceLimit = *req.DeviceLimit + } + if req.PassHash != nil && req.Nonce != nil && req.OpsLimit != nil && req.MemLimit != nil { + fileLinkRow.PassHash = req.PassHash + fileLinkRow.Nonce = req.Nonce + fileLinkRow.OpsLimit = req.OpsLimit + fileLinkRow.MemLimit = req.MemLimit + } else if req.DisablePassword != nil && *req.DisablePassword { + fileLinkRow.PassHash = nil + fileLinkRow.Nonce = nil + fileLinkRow.OpsLimit = nil + fileLinkRow.MemLimit = nil + } + if req.EnableDownload != nil { + fileLinkRow.EnableDownload = *req.EnableDownload + } + + err = c.FileLinkRepo.UpdateLink(ctx, *fileLinkRow) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return c.mapRowToFileUrl(ctx, fileLinkRow), nil +} + // VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be // used by the client to pass in other requests for public collection. // Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force diff --git a/server/pkg/repo/file.go b/server/pkg/repo/file.go index 2ae4eafdca..50945cf6b1 100644 --- a/server/pkg/repo/file.go +++ b/server/pkg/repo/file.go @@ -638,6 +638,16 @@ func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.Fi return result, nil } +func (repo *FileRepository) GetFileAttributes(fileID int64) (*ente.File, error) { + rows := repo.DB.QueryRow(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = $1`, fileID) + var file ente.File + err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + return &file, nil +} + // GetUsage gets the Storage usage of a user // Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage func (repo *FileRepository) GetUsage(userID int64) (int64, error) { diff --git a/server/pkg/repo/public/public_file.go b/server/pkg/repo/public/public_file.go index 20573031a7..3f6dbf7b33 100644 --- a/server/pkg/repo/public/public_file.go +++ b/server/pkg/repo/public/public_file.go @@ -46,16 +46,17 @@ func (pcr *FileLinkRepository) Insert( fileID int64, ownerID int64, token string, + app ente.App, ) (*string, error) { id, err := base.NewID("pft") if err != nil { return nil, stacktrace.Propagate(err, "failed to generate new ID for public file token") } _, err = pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens - (id, file_id, owner_id, access_token) VALUES ($1, $2, $3, $4)`, - id, fileID, ownerID, token) + (id, file_id, owner_id, access_token, app) VALUES ($1, $2, $3, $4, $5)`, + id, fileID, ownerID, token, string(app)) if err != nil { - if err.Error() == "pq: duplicate key value violates unique constraint \"public_access_token_unique_idx\"" { + if err.Error() == "pq: duplicate key value violates unique constraint \"public_file_token_unique_idx\"" { return nil, ente.ErrActiveLinkAlreadyExists } return nil, stacktrace.Propagate(err, "failed to insert") @@ -79,6 +80,54 @@ func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID } return &ret, nil } +func (pcr *FileLinkRepository) GetFileUrls(ctx context.Context, userID int64, sinceTime int64, limit int64, app ente.App) ([]*ente.FileLinkRow, error) { + if limit <= 0 { + limit = 500 + } + query := `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit, + created_at, updated_at FROM public_file_tokens + WHERE owner_id = $1 AND created_at > $2 AND app = $3 ORDER BY updated_at DESC LIMIT $4` + rows, err := pcr.DB.QueryContext(ctx, query, userID, sinceTime, string(app), limit) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get public file urls") + } + defer rows.Close() + + var result []*ente.FileLinkRow + for rows.Next() { + var row ente.FileLinkRow + err = rows.Scan(&row.LinkID, &row.FileID, &row.OwnerID, &row.IsDisabled, + &row.ValidTill, &row.DeviceLimit, &row.EnableDownload, + &row.PassHash, &row.Nonce, &row.MemLimit, + &row.OpsLimit, &row.CreatedAt, &row.UpdatedAt) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to scan public file url row") + } + result = append(result, &row) + } + return result, nil +} + +func (pcr *FileLinkRepository) DisableLinkForFiles(ctx context.Context, fileIDs []int64) error { + if len(fileIDs) == 0 { + return nil + } + query := `UPDATE public_file_tokens SET is_disabled = TRUE WHERE file_id = ANY($1)` + _, err := pcr.DB.ExecContext(ctx, query, fileIDs) + if err != nil { + return stacktrace.Propagate(err, "failed to disable public file links") + } + return nil +} + +// DisableLinksForUser will disable all public file links for the given user +func (pcr *FileLinkRepository) DisableLinksForUser(ctx context.Context, linkID string) error { + _, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET is_disabled = TRUE WHERE id = $1`, linkID) + if err != nil { + return stacktrace.Propagate(err, "failed to disable public file link") + } + return nil +} func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) { row := pcr.DB.QueryRowContext(ctx, @@ -98,6 +147,23 @@ func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessT return &result, nil } +func (pcr *FileLinkRepository) GetFileUrlRowByFileID(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) { + row := pcr.DB.QueryRowContext(ctx, + `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit + created_at, updated_at + from public_file_tokens + where file_id = $1 and is_disabled = FALSE`, fileID) + var result = ente.FileLinkRow{} + err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, ente.ErrNotFound + } + return nil, stacktrace.Propagate(err, "failed to get public file url summary by file ID") + } + return &result, nil +} + // UpdateLink will update the row for corresponding public file token func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error { _, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2, diff --git a/server/pkg/utils/auth/auth.go b/server/pkg/utils/auth/auth.go index 85acc995c3..8b52808e36 100644 --- a/server/pkg/utils/auth/auth.go +++ b/server/pkg/utils/auth/auth.go @@ -121,6 +121,8 @@ func GetCastToken(c *gin.Context) string { return token } +// GetAccessTokenJWT fetches the JWT access token from the request header or query parameters. +// This token is issued by server on password verification of links that are protected by password. func GetAccessTokenJWT(c *gin.Context) string { token := c.GetHeader("X-Auth-Access-Token-JWT") if token == "" {