[Server] Single file link (#6566)
## Description Adds 4 authenticate API for - Creating link for individual file - Update Link - Delete Link - Fetch all links (based on header, the server will return particular app's link) For link preview - API to get Info (pending discussion) - API to get file attributes (pending discussion) - APIs to get thumbnail and file - API to verify password Pending - [x] Clean up on account deletion - [x] Clean up on file deletion - [x] Clean up history for disabled links ## Tests Basic santiy check during client integration
This commit is contained in:
@@ -5,6 +5,9 @@ import (
|
||||
"database/sql"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/controller/collections"
|
||||
publicCtrl "github.com/ente-io/museum/pkg/controller/public"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -14,8 +17,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ente-io/museum/pkg/controller/collections"
|
||||
|
||||
"github.com/ente-io/museum/ente/base"
|
||||
"github.com/ente-io/museum/pkg/controller/emergency"
|
||||
"github.com/ente-io/museum/pkg/controller/file_copy"
|
||||
@@ -97,6 +98,7 @@ func main() {
|
||||
}
|
||||
|
||||
viper.SetDefault("apps.public-albums", "https://albums.ente.io")
|
||||
viper.SetDefault("apps.public-locker", "https://locker.ente.io")
|
||||
viper.SetDefault("apps.accounts", "https://accounts.ente.io")
|
||||
viper.SetDefault("apps.cast", "https://cast.ente.io")
|
||||
viper.SetDefault("apps.family", "https://family.ente.io")
|
||||
@@ -174,11 +176,13 @@ func main() {
|
||||
fileRepo := &repo.FileRepository{DB: db, S3Config: s3Config, QueueRepo: queueRepo,
|
||||
ObjectRepo: objectRepo, ObjectCleanupRepo: objectCleanupRepo,
|
||||
ObjectCopiesRepo: objectCopiesRepo, UsageRepo: usageRepo}
|
||||
fileLinkRepo := public.NewFileLinkRepo(db)
|
||||
fileDataRepo := &fileDataRepo.Repository{DB: db, ObjectCleanupRepo: objectCleanupRepo}
|
||||
familyRepo := &repo.FamilyRepository{DB: db}
|
||||
trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo}
|
||||
publicCollectionRepo := repo.NewPublicCollectionRepository(db, viper.GetString("apps.public-albums"))
|
||||
collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, PublicCollectionRepo: publicCollectionRepo,
|
||||
trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo, FileLinkRepo: fileLinkRepo}
|
||||
collectionLinkRepo := public.NewCollectionLinkRepository(db, viper.GetString("apps.public-albums"))
|
||||
|
||||
collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, CollectionLinkRepo: collectionLinkRepo,
|
||||
TrashRepo: trashRepo, SecretEncryptionKey: secretEncryptionKeyBytes, QueueRepo: queueRepo, LatencyLogger: latencyLogger}
|
||||
pushRepo := &repo.PushTokenRepository{DB: db}
|
||||
kexRepo := &kex.Repository{
|
||||
@@ -300,26 +304,27 @@ func main() {
|
||||
UsageRepo: usageRepo,
|
||||
}
|
||||
|
||||
publicCollectionCtrl := &controller.PublicCollectionController{
|
||||
collectionLinkCtrl := &publicCtrl.CollectionLinkController{
|
||||
FileController: fileController,
|
||||
EmailNotificationCtrl: emailNotificationCtrl,
|
||||
PublicCollectionRepo: publicCollectionRepo,
|
||||
CollectionLinkRepo: collectionLinkRepo,
|
||||
FileLinkRepo: fileLinkRepo,
|
||||
CollectionRepo: collectionRepo,
|
||||
UserRepo: userRepo,
|
||||
JwtSecret: jwtSecretBytes,
|
||||
}
|
||||
|
||||
collectionController := &collections.CollectionController{
|
||||
CollectionRepo: collectionRepo,
|
||||
EmailCtrl: emailNotificationCtrl,
|
||||
AccessCtrl: accessCtrl,
|
||||
PublicCollectionCtrl: publicCollectionCtrl,
|
||||
UserRepo: userRepo,
|
||||
FileRepo: fileRepo,
|
||||
CastRepo: &castDb,
|
||||
BillingCtrl: billingController,
|
||||
QueueRepo: queueRepo,
|
||||
TaskRepo: taskLockingRepo,
|
||||
CollectionRepo: collectionRepo,
|
||||
EmailCtrl: emailNotificationCtrl,
|
||||
AccessCtrl: accessCtrl,
|
||||
CollectionLinkCtrl: collectionLinkCtrl,
|
||||
UserRepo: userRepo,
|
||||
FileRepo: fileRepo,
|
||||
CastRepo: &castDb,
|
||||
BillingCtrl: billingController,
|
||||
QueueRepo: queueRepo,
|
||||
TaskRepo: taskLockingRepo,
|
||||
}
|
||||
|
||||
kexCtrl := &kexCtrl.Controller{
|
||||
@@ -351,6 +356,12 @@ func main() {
|
||||
userCache,
|
||||
userCacheCtrl,
|
||||
)
|
||||
fileLinkCtrl := &publicCtrl.FileLinkController{
|
||||
FileController: fileController,
|
||||
FileLinkRepo: fileLinkRepo,
|
||||
FileRepo: fileRepo,
|
||||
JwtSecret: jwtSecretBytes,
|
||||
}
|
||||
|
||||
passkeyCtrl := &controller.PasskeyController{
|
||||
Repo: passkeysRepo,
|
||||
@@ -358,14 +369,21 @@ func main() {
|
||||
}
|
||||
|
||||
authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController}
|
||||
accessTokenMiddleware := middleware.AccessTokenMiddleware{
|
||||
PublicCollectionRepo: publicCollectionRepo,
|
||||
PublicCollectionCtrl: publicCollectionCtrl,
|
||||
collectionLinkMiddleware := middleware.CollectionLinkMiddleware{
|
||||
CollectionLinkRepo: collectionLinkRepo,
|
||||
PublicCollectionCtrl: collectionLinkCtrl,
|
||||
CollectionRepo: collectionRepo,
|
||||
Cache: accessTokenCache,
|
||||
BillingCtrl: billingController,
|
||||
DiscordController: discordController,
|
||||
}
|
||||
fileLinkMiddleware := &middleware.FileLinkMiddleware{
|
||||
FileLinkRepo: fileLinkRepo,
|
||||
FileLinkCtrl: fileLinkCtrl,
|
||||
Cache: accessTokenCache,
|
||||
BillingCtrl: billingController,
|
||||
DiscordController: discordController,
|
||||
}
|
||||
|
||||
if environment != "local" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -404,7 +422,9 @@ func main() {
|
||||
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
|
||||
|
||||
publicCollectionAPI := server.Group("/public-collection")
|
||||
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
|
||||
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionLinkMiddleware.Authenticate(urlSanitizer))
|
||||
fileLinkApi := server.GET("/file-link")
|
||||
fileLinkApi.Use(rateLimiter.GlobalRateLimiter(), fileLinkMiddleware.Authenticate(urlSanitizer))
|
||||
|
||||
healthCheckHandler := &api.HealthCheckHandler{
|
||||
DB: db,
|
||||
@@ -432,6 +452,7 @@ func main() {
|
||||
Controller: fileController,
|
||||
FileCopyCtrl: fileCopyCtrl,
|
||||
FileDataCtrl: fileDataCtrl,
|
||||
FileUrlCtrl: fileLinkCtrl,
|
||||
}
|
||||
privateAPI.GET("/files/upload-urls", fileHandler.GetUploadURLs)
|
||||
privateAPI.GET("/files/multipart-upload-urls", fileHandler.GetMultipartUploadURLs)
|
||||
@@ -440,6 +461,11 @@ func main() {
|
||||
privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail)
|
||||
privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail)
|
||||
|
||||
privateAPI.POST("/files/share-url", fileHandler.ShareUrl)
|
||||
privateAPI.PUT("/files/share-url", fileHandler.UpdateFileURL)
|
||||
privateAPI.DELETE("/files/share-url/:fileID", fileHandler.DisableUrl)
|
||||
privateAPI.GET("/files/share-urls/", fileHandler.GetUrls)
|
||||
|
||||
privateAPI.PUT("/files/data", fileHandler.PutFileData)
|
||||
privateAPI.PUT("/files/video-data", fileHandler.PutVideoData)
|
||||
privateAPI.POST("/files/data/status-diff", fileHandler.FileDataStatusDiff)
|
||||
@@ -566,13 +592,19 @@ func main() {
|
||||
privateAPI.PUT("/collections/sharee-magic-metadata", collectionHandler.ShareeMagicMetadataUpdate)
|
||||
|
||||
publicCollectionHandler := &api.PublicCollectionHandler{
|
||||
Controller: publicCollectionCtrl,
|
||||
Controller: collectionLinkCtrl,
|
||||
FileCtrl: fileController,
|
||||
CollectionCtrl: collectionController,
|
||||
FileDataCtrl: fileDataCtrl,
|
||||
StorageBonusController: storageBonusCtrl,
|
||||
}
|
||||
|
||||
fileLinkApi.GET("/info", fileHandler.LinkInfo)
|
||||
fileLinkApi.GET("/pass-info", fileHandler.PasswordInfo)
|
||||
fileLinkApi.GET("/thumbnail", fileHandler.LinkThumbnail)
|
||||
fileLinkApi.GET("/file", fileHandler.LinkFile)
|
||||
fileLinkApi.POST("/verify-password", fileHandler.VerifyPassword)
|
||||
|
||||
publicCollectionAPI.GET("/files/preview/:fileID", publicCollectionHandler.GetThumbnail)
|
||||
publicCollectionAPI.GET("/files/download/:fileID", publicCollectionHandler.GetFile)
|
||||
publicCollectionAPI.GET("/files/data/fetch", publicCollectionHandler.GetFileData)
|
||||
@@ -770,7 +802,7 @@ func main() {
|
||||
setKnownAPIs(server.Routes())
|
||||
setupAndStartBackgroundJobs(objectCleanupController, replicationController3, fileDataCtrl)
|
||||
setupAndStartCrons(
|
||||
userAuthRepo, publicCollectionRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl,
|
||||
userAuthRepo, collectionLinkRepo, fileLinkRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl,
|
||||
trashController, pushController, objectController, dataCleanupController, storageBonusCtrl, emergencyCtrl,
|
||||
embeddingController, healthCheckHandler, kexCtrl, castDb)
|
||||
|
||||
@@ -899,7 +931,8 @@ func setupAndStartBackgroundJobs(
|
||||
objectCleanupController.StartClearingOrphanObjects()
|
||||
}
|
||||
|
||||
func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionRepo *repo.PublicCollectionRepository,
|
||||
func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, collectionLinkRepo *public.CollectionLinkRepo,
|
||||
fileLinkRepo *public.FileLinkRepository,
|
||||
twoFactorRepo *repo.TwoFactorRepository, passkeysRepo *passkey.Repository, fileController *controller.FileController,
|
||||
taskRepo *repo.TaskLockRepository, emailNotificationCtrl *email.EmailNotificationController,
|
||||
trashController *controller.TrashController, pushController *controller.PushController,
|
||||
@@ -925,7 +958,8 @@ func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionR
|
||||
schedule(c, "@every 24h", func() {
|
||||
_ = userAuthRepo.RemoveDeletedTokens(timeUtil.MicrosecondsBeforeDays(30))
|
||||
_ = castDb.DeleteOldSessions(context.Background(), timeUtil.MicrosecondsBeforeDays(7))
|
||||
_ = publicCollectionRepo.CleanupAccessHistory(context.Background())
|
||||
_ = collectionLinkRepo.CleanupAccessHistory(context.Background())
|
||||
_ = fileLinkRepo.CleanupAccessHistory(context.Background())
|
||||
})
|
||||
|
||||
schedule(c, "@every 1m", func() {
|
||||
|
||||
@@ -79,9 +79,14 @@ http:
|
||||
apps:
|
||||
# Default is https://albums.ente.io
|
||||
#
|
||||
# If you're running a self hosted instance and wish to serve public links,
|
||||
# If you're running a self hosted instance and wish to serve public links for photos,
|
||||
# set this to the URL where your albums web app is running.
|
||||
public-albums:
|
||||
# Default is https://locker.ente.io
|
||||
#
|
||||
# If you're running a self-hosted instance and wish to serve public links for locker,
|
||||
# set this to the URL where your albums web app is running.
|
||||
public-locker:
|
||||
# Default is https://cast.ente.io
|
||||
cast:
|
||||
# Default is https://accounts.ente.io
|
||||
|
||||
@@ -97,8 +97,8 @@ var ErrUserDeleted = errors.New("user account has been deleted")
|
||||
// ErrLockUnavailable is thrown when a lock could not be acquired
|
||||
var ErrLockUnavailable = errors.New("could not acquire lock")
|
||||
|
||||
// ErrActiveLinkAlreadyExists is thrown when the collection already has active public link
|
||||
var ErrActiveLinkAlreadyExists = errors.New("Collection already has active public link")
|
||||
// ErrActiveLinkAlreadyExists is thrown when an active link already exists for entity
|
||||
var ErrActiveLinkAlreadyExists = errors.New("link already exists for this entity")
|
||||
|
||||
// ErrNotImplemented indicates that the action that we tried to perform is not
|
||||
// available at this museum instance. e.g. this could be something that is not
|
||||
@@ -176,6 +176,11 @@ var ErrMaxPasskeysReached = ApiError{
|
||||
Message: "Max passkeys limit reached",
|
||||
HttpStatusCode: http.StatusConflict,
|
||||
}
|
||||
var ErrPassProtectedResource = ApiError{
|
||||
Code: "PASS_PROTECTED_RESOURCE",
|
||||
Message: "This resource is password protected",
|
||||
HttpStatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
var ErrCastPermissionDenied = ApiError{
|
||||
Code: "CAST_PERMISSION_DENIED",
|
||||
|
||||
94
server/ente/file_link.go
Normal file
94
server/ente/file_link.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package ente
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
)
|
||||
|
||||
// CreateFileUrl represents an encrypted file in the system
|
||||
type CreateFileUrl struct {
|
||||
FileID int64 `json:"fileID" binding:"required"`
|
||||
App App `json:"app" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateFileUrl ..
|
||||
type UpdateFileUrl struct {
|
||||
LinkID string `json:"linkID" binding:"required"`
|
||||
FileID int64 `json:"fileID" binding:"required"`
|
||||
ValidTill *int64 `json:"validTill"`
|
||||
DeviceLimit *int `json:"deviceLimit"`
|
||||
PassHash *string
|
||||
Nonce *string
|
||||
MemLimit *int64
|
||||
OpsLimit *int64
|
||||
EnableDownload *bool `json:"enableDownload"`
|
||||
DisablePassword *bool `json:"disablePassword"`
|
||||
}
|
||||
|
||||
func (ut *UpdateFileUrl) Validate() error {
|
||||
if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil &&
|
||||
ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil {
|
||||
return NewBadRequestWithMessage("all parameters are missing")
|
||||
}
|
||||
|
||||
if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) {
|
||||
return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit))
|
||||
}
|
||||
|
||||
if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() {
|
||||
return NewBadRequestWithMessage("valid till should be greater than current timestamp")
|
||||
}
|
||||
|
||||
var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil
|
||||
var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil
|
||||
|
||||
if !(allPassParamsMissing || allPassParamsPresent) {
|
||||
return NewBadRequestWithMessage("all password params should be either present or missing")
|
||||
}
|
||||
|
||||
if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword {
|
||||
return NewBadRequestWithMessage("can not set and disable password in same request")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FileLinkRow struct {
|
||||
LinkID string
|
||||
OwnerID int64
|
||||
FileID int64
|
||||
Token string
|
||||
DeviceLimit int
|
||||
ValidTill int64
|
||||
IsDisabled bool
|
||||
PassHash *string
|
||||
Nonce *string
|
||||
MemLimit *int64
|
||||
OpsLimit *int64
|
||||
EnableDownload bool
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
}
|
||||
|
||||
type FileUrl struct {
|
||||
LinkID string `json:"linkID" binding:"required"`
|
||||
URL string `json:"url" binding:"required"`
|
||||
OwnerID int64 `json:"ownerID" binding:"required"`
|
||||
FileID int64 `json:"fileID" binding:"required"`
|
||||
ValidTill int64 `json:"validTill"`
|
||||
DeviceLimit int `json:"deviceLimit"`
|
||||
PasswordEnabled bool `json:"passwordEnabled"`
|
||||
// Nonce contains the nonce value for the password if the link is password protected.
|
||||
Nonce *string `json:"nonce,omitempty"`
|
||||
MemLimit *int64 `json:"memLimit,omitempty"`
|
||||
OpsLimit *int64 `json:"opsLimit,omitempty"`
|
||||
EnableDownload bool `json:"enableDownload"`
|
||||
CreatedAt int64 `json:"createdAt"`
|
||||
}
|
||||
|
||||
type FileLinkAccessContext struct {
|
||||
LinkID string
|
||||
IP string
|
||||
UserAgent string
|
||||
FileID int64
|
||||
OwnerID int64
|
||||
}
|
||||
@@ -40,13 +40,13 @@ func (w WebCommonJWTClaim) Valid() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicAlbumPasswordClaim refer to token granted post public album password verification
|
||||
type PublicAlbumPasswordClaim struct {
|
||||
// LinkPasswordClaim refer to token granted post link password verification
|
||||
type LinkPasswordClaim struct {
|
||||
PassHash string `json:"passKey"`
|
||||
ExpiryTime int64 `json:"expiryTime"`
|
||||
}
|
||||
|
||||
func (c PublicAlbumPasswordClaim) Valid() error {
|
||||
func (c LinkPasswordClaim) Valid() error {
|
||||
if c.ExpiryTime < time.Microseconds() {
|
||||
return errors.New("token expired")
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package ente
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
)
|
||||
@@ -32,6 +32,33 @@ type UpdatePublicAccessTokenRequest struct {
|
||||
EnableJoin *bool `json:"enableJoin"`
|
||||
}
|
||||
|
||||
func (ut *UpdatePublicAccessTokenRequest) Validate() error {
|
||||
if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil &&
|
||||
ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil && ut.EnableCollect == nil {
|
||||
return NewBadRequestWithMessage("all parameters are missing")
|
||||
}
|
||||
|
||||
if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) {
|
||||
return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit))
|
||||
}
|
||||
|
||||
if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() {
|
||||
return NewBadRequestWithMessage("valid till should be greater than current timestamp")
|
||||
}
|
||||
|
||||
var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil
|
||||
var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil
|
||||
|
||||
if !(allPassParamsMissing || allPassParamsPresent) {
|
||||
return NewBadRequestWithMessage("all password params should be either present or missing")
|
||||
}
|
||||
|
||||
if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword {
|
||||
return NewBadRequestWithMessage("can not set and disable password in same request")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type VerifyPasswordRequest struct {
|
||||
PassHash string `json:"passHash" binding:"required"`
|
||||
}
|
||||
@@ -40,8 +67,8 @@ type VerifyPasswordResponse struct {
|
||||
JWTToken string `json:"jwtToken"`
|
||||
}
|
||||
|
||||
// PublicCollectionToken represents row entity for public_collection_token table
|
||||
type PublicCollectionToken struct {
|
||||
// CollectionLinkRow represents row entity for public_collection_token table
|
||||
type CollectionLinkRow struct {
|
||||
ID int64
|
||||
CollectionID int64
|
||||
Token string
|
||||
@@ -57,7 +84,7 @@ type PublicCollectionToken struct {
|
||||
EnableJoin bool
|
||||
}
|
||||
|
||||
func (p PublicCollectionToken) CanJoin() error {
|
||||
func (p CollectionLinkRow) CanJoin() error {
|
||||
if p.IsDisabled {
|
||||
return NewBadRequestWithMessage("link disabled")
|
||||
}
|
||||
|
||||
3
server/migrations/103_single_file_url.down.sql
Normal file
3
server/migrations/103_single_file_url.down.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
|
||||
DROP TABLE IF EXISTS public_file_tokens_access_history;
|
||||
DROP TABLE IF EXISTS public_file_tokens;
|
||||
46
server/migrations/103_single_file_url.up.sql
Normal file
46
server/migrations/103_single_file_url.up.sql
Normal file
@@ -0,0 +1,46 @@
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS public_file_tokens
|
||||
(
|
||||
id text primary key,
|
||||
file_id bigint NOT NULL,
|
||||
owner_id bigint NOT NULL,
|
||||
app text NOT NULL,
|
||||
access_token text not null,
|
||||
valid_till bigint not null DEFAULT 0,
|
||||
device_limit int not null DEFAULT 0,
|
||||
is_disabled bool not null DEFAULT FALSE,
|
||||
enable_download bool not null DEFAULT TRUE,
|
||||
pw_hash TEXT,
|
||||
pw_nonce TEXT,
|
||||
mem_limit BIGINT,
|
||||
ops_limit BIGINT,
|
||||
created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(),
|
||||
updated_at bigint NOT NULL DEFAULT now_utc_micro_seconds()
|
||||
);
|
||||
|
||||
|
||||
CREATE OR REPLACE TRIGGER update_public_file_tokens_updated_at
|
||||
BEFORE UPDATE
|
||||
ON public_file_tokens
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE
|
||||
trigger_updated_at_microseconds_column();
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS public_file_tokens_access_history
|
||||
(
|
||||
id text NOT NULL,
|
||||
ip text not null,
|
||||
user_agent text not null,
|
||||
created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(),
|
||||
CONSTRAINT unique_access_id_ip_ua UNIQUE (id, ip, user_agent),
|
||||
CONSTRAINT fk_public_file_history_token_id
|
||||
FOREIGN KEY (id)
|
||||
REFERENCES public_file_tokens (id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS public_file_token_unique_idx ON public_file_tokens (access_token) WHERE is_disabled = FALSE;
|
||||
CREATE INDEX IF NOT EXISTS public_file_tokens_owner_id_updated_at_idx ON public_file_tokens (owner_id, updated_at);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS public_active_file_link_unique_idx ON public_file_tokens (file_id, is_disabled) WHERE is_disabled = FALSE;
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/handler"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
@@ -172,35 +171,6 @@ func (h *CollectionHandler) UpdateShareURL(c *gin.Context) {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
if req.DeviceLimit == nil && req.ValidTill == nil && req.DisablePassword == nil &&
|
||||
req.Nonce == nil && req.PassHash == nil && req.EnableDownload == nil && req.EnableCollect == nil {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all parameters are missing"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.DeviceLimit != nil && (*req.DeviceLimit < 0 || *req.DeviceLimit > controller.DeviceLimitThreshold) {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("device limit: %d out of range", *req.DeviceLimit)))
|
||||
return
|
||||
}
|
||||
|
||||
if req.ValidTill != nil && *req.ValidTill != 0 && *req.ValidTill < time.Microseconds() {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "valid till should be greater than current timestamp"))
|
||||
return
|
||||
}
|
||||
|
||||
var allPassParamsMissing = req.Nonce == nil && req.PassHash == nil && req.MemLimit == nil && req.OpsLimit == nil
|
||||
var allPassParamsPresent = req.Nonce != nil && req.PassHash != nil && req.MemLimit != nil && req.OpsLimit != nil
|
||||
|
||||
if !(allPassParamsMissing || allPassParamsPresent) {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all password params should be either present or missing"))
|
||||
return
|
||||
}
|
||||
|
||||
if allPassParamsPresent && req.DisablePassword != nil && *req.DisablePassword {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "can not set and disable password in same request"))
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.Controller.UpdateShareURL(c, auth.GetUserID(c.Request.Header), req)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/controller/file_copy"
|
||||
"github.com/ente-io/museum/pkg/controller/filedata"
|
||||
"github.com/ente-io/museum/pkg/controller/public"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
// FileHandler exposes request handlers for all encrypted file related requests
|
||||
type FileHandler struct {
|
||||
Controller *controller.FileController
|
||||
FileUrlCtrl *public.FileLinkController
|
||||
FileCopyCtrl *file_copy.FileCopyController
|
||||
FileDataCtrl *filedata.Controller
|
||||
}
|
||||
|
||||
141
server/pkg/api/file_link.go
Normal file
141
server/pkg/api/file_link.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/handler"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ShareUrl a sharable url for the file
|
||||
func (h *FileHandler) ShareUrl(c *gin.Context) {
|
||||
var file ente.CreateFileUrl
|
||||
if err := c.ShouldBindJSON(&file); err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.FileUrlCtrl.CreateLink(c, file)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *FileHandler) LinkInfo(c *gin.Context) {
|
||||
resp, err := h.FileUrlCtrl.Info(c)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"file": resp,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FileHandler) PasswordInfo(c *gin.Context) {
|
||||
resp, err := h.FileUrlCtrl.PassInfo(c)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"nonce": resp.Nonce,
|
||||
"opsLimit": resp.OpsLimit,
|
||||
"memLimit": resp.MemLimit,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FileHandler) LinkThumbnail(c *gin.Context) {
|
||||
linkCtx := auth.MustGetFileLinkAccessContext(c)
|
||||
url, err := h.Controller.GetThumbnailURL(c, linkCtx.OwnerID, linkCtx.FileID)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
func (h *FileHandler) LinkFile(c *gin.Context) {
|
||||
linkCtx := auth.MustGetFileLinkAccessContext(c)
|
||||
url, err := h.Controller.GetFileURL(c, linkCtx.OwnerID, linkCtx.FileID)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
func (h *FileHandler) DisableUrl(c *gin.Context) {
|
||||
cID, err := strconv.ParseInt(c.Param("fileID"), 10, 64)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, ""))
|
||||
return
|
||||
}
|
||||
err = h.FileUrlCtrl.Disable(c, cID)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
}
|
||||
|
||||
func (h *FileHandler) GetUrls(c *gin.Context) {
|
||||
sinceTime, err := strconv.ParseInt(c.Query("sinceTime"), 10, 64)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "sinceTime parsing failed"))
|
||||
return
|
||||
}
|
||||
limit := 500
|
||||
if c.Query("limit") != "" {
|
||||
limit, err = strconv.Atoi(c.Query("limit"))
|
||||
if err != nil || limit < 1 {
|
||||
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, ""))
|
||||
return
|
||||
}
|
||||
}
|
||||
response, err := h.FileUrlCtrl.GetUrls(c, sinceTime, int64(limit))
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"diff": response,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyPassword verifies the password for given link access token and return signed jwt token if it's valid
|
||||
func (h *FileHandler) VerifyPassword(c *gin.Context) {
|
||||
var req ente.VerifyPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
resp, err := h.FileUrlCtrl.VerifyPassword(c, req)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// UpdateFileURL updates the share URL for a file
|
||||
func (h *FileHandler) UpdateFileURL(c *gin.Context) {
|
||||
var req ente.UpdateFileUrl
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
response, err := h.FileUrlCtrl.UpdateSharedUrl(c, req)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"result": response,
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
fileData "github.com/ente-io/museum/ente/filedata"
|
||||
"github.com/ente-io/museum/pkg/controller/collections"
|
||||
"github.com/ente-io/museum/pkg/controller/filedata"
|
||||
"github.com/ente-io/museum/pkg/controller/public"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -20,7 +21,7 @@ import (
|
||||
|
||||
// PublicCollectionHandler exposes request handlers for publicly accessible collections
|
||||
type PublicCollectionHandler struct {
|
||||
Controller *controller.PublicCollectionController
|
||||
Controller *public.CollectionLinkController
|
||||
FileCtrl *controller.FileController
|
||||
CollectionCtrl *collections.CollectionController
|
||||
FileDataCtrl *filedata.Controller
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/controller/access"
|
||||
"github.com/ente-io/museum/pkg/controller/email"
|
||||
"github.com/ente-io/museum/pkg/controller/public"
|
||||
"github.com/ente-io/museum/pkg/repo/cast"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
@@ -24,16 +25,16 @@ const (
|
||||
|
||||
// CollectionController encapsulates logic that deals with collections
|
||||
type CollectionController struct {
|
||||
PublicCollectionCtrl *controller.PublicCollectionController
|
||||
EmailCtrl *email.EmailNotificationController
|
||||
AccessCtrl access.Controller
|
||||
BillingCtrl *controller.BillingController
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
UserRepo *repo.UserRepository
|
||||
FileRepo *repo.FileRepository
|
||||
QueueRepo *repo.QueueRepository
|
||||
CastRepo *cast.Repository
|
||||
TaskRepo *repo.TaskLockRepository
|
||||
CollectionLinkCtrl *public.CollectionLinkController
|
||||
EmailCtrl *email.EmailNotificationController
|
||||
AccessCtrl access.Controller
|
||||
BillingCtrl *controller.BillingController
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
UserRepo *repo.UserRepository
|
||||
FileRepo *repo.FileRepository
|
||||
QueueRepo *repo.QueueRepository
|
||||
CastRepo *cast.Repository
|
||||
TaskRepo *repo.TaskLockRepository
|
||||
}
|
||||
|
||||
// Create creates a collection
|
||||
@@ -148,7 +149,7 @@ func (c *CollectionController) TrashV3(ctx *gin.Context, req ente.TrashCollectio
|
||||
}
|
||||
|
||||
}
|
||||
err = c.PublicCollectionCtrl.Disable(ctx, cID)
|
||||
err = c.CollectionLinkCtrl.Disable(ctx, cID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to disabled public share url")
|
||||
}
|
||||
@@ -209,7 +210,7 @@ func (c *CollectionController) HandleAccountDeletion(ctx context.Context, userID
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to revoke cast token for user")
|
||||
}
|
||||
err = c.PublicCollectionCtrl.HandleAccountDeletion(ctx, userID, logger)
|
||||
err = c.CollectionLinkCtrl.HandleAccountDeletion(ctx, userID, logger)
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
|
||||
@@ -70,21 +70,21 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec
|
||||
if !collection.AllowSharing() {
|
||||
return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("joining %s is not allowed", collection.Type))
|
||||
}
|
||||
publicCollectionToken, err := c.PublicCollectionCtrl.GetActivePublicCollectionToken(ctx, req.CollectionID)
|
||||
collectionLinkToken, err := c.CollectionLinkCtrl.GetActiveCollectionLinkToken(ctx, req.CollectionID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
if canJoin := publicCollectionToken.CanJoin(); canJoin != nil {
|
||||
if canJoin := collectionLinkToken.CanJoin(); canJoin != nil {
|
||||
return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("can not join collection: %s", canJoin.Error()))
|
||||
}
|
||||
accessToken := auth.GetAccessToken(ctx)
|
||||
if publicCollectionToken.Token != accessToken {
|
||||
if collectionLinkToken.Token != accessToken {
|
||||
return stacktrace.Propagate(ente.ErrPermissionDenied, "token doesn't match collection")
|
||||
}
|
||||
if publicCollectionToken.PassHash != nil && *publicCollectionToken.PassHash != "" {
|
||||
if collectionLinkToken.PassHash != nil && *collectionLinkToken.PassHash != "" {
|
||||
accessTokenJWT := auth.GetAccessTokenJWT(ctx)
|
||||
if passCheckErr := c.PublicCollectionCtrl.ValidateJWTToken(ctx, accessTokenJWT, *publicCollectionToken.PassHash); passCheckErr != nil {
|
||||
if passCheckErr := c.CollectionLinkCtrl.ValidateJWTToken(ctx, accessTokenJWT, *collectionLinkToken.PassHash); passCheckErr != nil {
|
||||
return stacktrace.Propagate(passCheckErr, "")
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
role := ente.VIEWER
|
||||
if publicCollectionToken.EnableCollect {
|
||||
if collectionLinkToken.EnableCollect {
|
||||
role = ente.COLLABORATOR
|
||||
}
|
||||
joinErr := c.CollectionRepo.Share(req.CollectionID, collection.Owner.ID, userID, req.EncryptedKey, role, time.Microseconds())
|
||||
@@ -197,7 +197,7 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
}
|
||||
response, err := c.PublicCollectionCtrl.CreateAccessToken(ctx, req)
|
||||
response, err := c.CollectionLinkCtrl.CreateLink(ctx, req)
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
}
|
||||
@@ -205,20 +205,26 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e
|
||||
}
|
||||
|
||||
// UpdateShareURL updates the shared url configuration
|
||||
func (c *CollectionController) UpdateShareURL(ctx context.Context, userID int64, req ente.UpdatePublicAccessTokenRequest) (
|
||||
ente.PublicURL, error) {
|
||||
func (c *CollectionController) UpdateShareURL(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
req ente.UpdatePublicAccessTokenRequest,
|
||||
) (*ente.PublicURL, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
if err := c.verifyOwnership(req.CollectionID, userID); err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
response, err := c.PublicCollectionCtrl.UpdateSharedUrl(ctx, req)
|
||||
response, err := c.CollectionLinkCtrl.UpdateSharedUrl(ctx, req)
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return response, nil
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// DisableSharedURL disable a public auth-token for the given collectionID
|
||||
@@ -226,7 +232,7 @@ func (c *CollectionController) DisableSharedURL(ctx context.Context, userID int6
|
||||
if err := c.verifyOwnership(cID, userID); err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
err := c.PublicCollectionCtrl.Disable(ctx, cID)
|
||||
err := c.CollectionLinkCtrl.Disable(ctx, cID)
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package controller
|
||||
package public
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
enteJWT "github.com/ente-io/museum/ente/jwt"
|
||||
emailCtrl "github.com/ente-io/museum/pkg/controller/email"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/lithammer/shortuuid/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -49,23 +49,24 @@ const (
|
||||
AbuseLimitExceededTemplate = "report_limit_exceeded_alert.html"
|
||||
)
|
||||
|
||||
// PublicCollectionController controls share collection operations
|
||||
type PublicCollectionController struct {
|
||||
FileController *FileController
|
||||
// CollectionLinkController controls share collection operations
|
||||
type CollectionLinkController struct {
|
||||
FileController *controller.FileController
|
||||
EmailNotificationCtrl *emailCtrl.EmailNotificationController
|
||||
PublicCollectionRepo *repo.PublicCollectionRepository
|
||||
CollectionLinkRepo *public.CollectionLinkRepo
|
||||
FileLinkRepo *public.FileLinkRepository
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
UserRepo *repo.UserRepository
|
||||
JwtSecret []byte
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) {
|
||||
func (c *CollectionLinkController) CreateLink(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) {
|
||||
accessToken := shortuuid.New()[0:AccessTokenLength]
|
||||
err := c.PublicCollectionRepo.
|
||||
err := c.CollectionLinkRepo.
|
||||
Insert(ctx, req.CollectionID, accessToken, req.ValidTill, req.DeviceLimit, req.EnableCollect, req.EnableJoin)
|
||||
if err != nil {
|
||||
if errors.Is(err, ente.ErrActiveLinkAlreadyExists) {
|
||||
collectionToPubUrlMap, err2 := c.PublicCollectionRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID})
|
||||
collectionToPubUrlMap, err2 := c.CollectionLinkRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID})
|
||||
if err2 != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err2, "")
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req
|
||||
}
|
||||
}
|
||||
response := ente.PublicURL{
|
||||
URL: c.PublicCollectionRepo.GetAlbumUrl(accessToken),
|
||||
URL: c.CollectionLinkRepo.GetAlbumUrl(accessToken),
|
||||
ValidTill: req.ValidTill,
|
||||
DeviceLimit: req.DeviceLimit,
|
||||
EnableDownload: true,
|
||||
@@ -91,11 +92,11 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) {
|
||||
return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID)
|
||||
func (c *CollectionLinkController) GetActiveCollectionLinkToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
|
||||
return c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, collectionID)
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) {
|
||||
func (c *CollectionLinkController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) {
|
||||
collection, err := c.GetPublicCollection(ctx, true)
|
||||
if err != nil {
|
||||
return ente.File{}, stacktrace.Propagate(err, "")
|
||||
@@ -118,13 +119,13 @@ func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File
|
||||
}
|
||||
|
||||
// Disable all public accessTokens generated for the given cID till date.
|
||||
func (c *PublicCollectionController) Disable(ctx context.Context, cID int64) error {
|
||||
err := c.PublicCollectionRepo.DisableSharing(ctx, cID)
|
||||
func (c *CollectionLinkController) Disable(ctx context.Context, cID int64) error {
|
||||
err := c.CollectionLinkRepo.DisableSharing(ctx, cID)
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) {
|
||||
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, req.CollectionID)
|
||||
func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) {
|
||||
publicCollectionToken, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, req.CollectionID)
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, err
|
||||
}
|
||||
@@ -154,12 +155,12 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en
|
||||
if req.EnableJoin != nil {
|
||||
publicCollectionToken.EnableJoin = *req.EnableJoin
|
||||
}
|
||||
err = c.PublicCollectionRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken)
|
||||
err = c.CollectionLinkRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken)
|
||||
if err != nil {
|
||||
return ente.PublicURL{}, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return ente.PublicURL{
|
||||
URL: c.PublicCollectionRepo.GetAlbumUrl(publicCollectionToken.Token),
|
||||
URL: c.CollectionLinkRepo.GetAlbumUrl(publicCollectionToken.Token),
|
||||
DeviceLimit: publicCollectionToken.DeviceLimit,
|
||||
ValidTill: publicCollectionToken.ValidTill,
|
||||
EnableDownload: publicCollectionToken.EnableDownload,
|
||||
@@ -176,58 +177,23 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en
|
||||
// used by the client to pass in other requests for public collection.
|
||||
// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force
|
||||
// attack for guessing password.
|
||||
func (c *PublicCollectionController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
|
||||
func (c *CollectionLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
|
||||
accessContext := auth.MustGetPublicAccessContext(ctx)
|
||||
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, accessContext.CollectionID)
|
||||
collectionLinkRow, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, accessContext.CollectionID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get public collection info")
|
||||
}
|
||||
if publicCollectionToken.PassHash == nil || *publicCollectionToken.PassHash == "" {
|
||||
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
|
||||
}
|
||||
if req.PassHash != *publicCollectionToken.PassHash {
|
||||
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.PublicAlbumPasswordClaim{
|
||||
PassHash: req.PassHash,
|
||||
ExpiryTime: time.NDaysFromNow(365),
|
||||
})
|
||||
// Sign and get the complete encoded token as a string using the secret
|
||||
tokenString, err := token.SignedString(c.JwtSecret)
|
||||
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return &ente.VerifyPasswordResponse{
|
||||
JWTToken: tokenString,
|
||||
}, nil
|
||||
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
|
||||
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.PublicAlbumPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
|
||||
}
|
||||
return c.JwtSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "JWT parsed failed")
|
||||
}
|
||||
claims, ok := token.Claims.(*enteJWT.PublicAlbumPasswordClaim)
|
||||
|
||||
if !ok {
|
||||
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
|
||||
}
|
||||
if token.Valid && claims.PassHash == passwordHash {
|
||||
return nil
|
||||
}
|
||||
return ente.ErrInvalidPassword
|
||||
func (c *CollectionLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
|
||||
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
|
||||
}
|
||||
|
||||
// ReportAbuse captures abuse report for a publicly shared collection.
|
||||
// It will also disable the accessToken for the collection if total abuse reports for the said collection
|
||||
// reaches AutoDisableAbuseThreshold
|
||||
func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error {
|
||||
func (c *CollectionLinkController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error {
|
||||
accessContext := auth.MustGetPublicAccessContext(ctx)
|
||||
readableReason, found := AllowedReasons[req.Reason]
|
||||
if !found {
|
||||
@@ -235,11 +201,11 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus
|
||||
}
|
||||
logrus.WithField("collectionID", accessContext.CollectionID).Error("CRITICAL: received abuse report")
|
||||
|
||||
err := c.PublicCollectionRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details)
|
||||
err := c.CollectionLinkRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
count, err := c.PublicCollectionRepo.GetAbuseReportCount(ctx, accessContext)
|
||||
count, err := c.CollectionLinkRepo.GetAbuseReportCount(ctx, accessContext)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
@@ -253,7 +219,7 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) {
|
||||
func (c *CollectionLinkController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) {
|
||||
collection, err := c.CollectionRepo.Get(collectionID)
|
||||
if err != nil {
|
||||
logrus.Error("Could not get collection for abuse report")
|
||||
@@ -292,9 +258,9 @@ func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, r
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error {
|
||||
func (c *CollectionLinkController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error {
|
||||
logger.Info("updating public collection on account deletion")
|
||||
collectionIDs, err := c.PublicCollectionRepo.GetActivePublicTokenForUser(ctx, userID)
|
||||
collectionIDs, err := c.CollectionLinkRepo.GetActivePublicTokenForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
@@ -305,12 +271,12 @@ func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context,
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return c.FileLinkRepo.DisableLinksForUser(ctx, userID)
|
||||
}
|
||||
|
||||
// GetPublicCollection will return collection info for a public url.
|
||||
// is mustAllowCollect is set to true but the underlying collection doesn't allow uploading
|
||||
func (c *PublicCollectionController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) {
|
||||
func (c *CollectionLinkController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) {
|
||||
accessContext := auth.MustGetPublicAccessContext(ctx)
|
||||
collection, err := c.CollectionRepo.Get(accessContext.CollectionID)
|
||||
if err != nil {
|
||||
162
server/pkg/controller/public/file_link.go
Normal file
162
server/pkg/controller/public/file_link.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lithammer/shortuuid/v3"
|
||||
)
|
||||
|
||||
// FileLinkController controls share collection operations
|
||||
type FileLinkController struct {
|
||||
FileController *controller.FileController
|
||||
FileLinkRepo *public.FileLinkRepository
|
||||
FileRepo *repo.FileRepository
|
||||
JwtSecret []byte
|
||||
}
|
||||
|
||||
func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl) (*ente.FileUrl, error) {
|
||||
actorUserID := auth.GetUserID(ctx.Request.Header)
|
||||
app := auth.GetApp(ctx)
|
||||
if req.App != app {
|
||||
return nil, stacktrace.Propagate(ente.NewBadRequestWithMessage("app mismatch"), "app mismatch")
|
||||
}
|
||||
file, err := c.FileRepo.GetFileAttributes(req.FileID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get file attributes")
|
||||
}
|
||||
if actorUserID != file.OwnerID {
|
||||
return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
|
||||
}
|
||||
accessToken := shortuuid.New()[0:AccessTokenLength]
|
||||
_, err = c.FileLinkRepo.Insert(ctx, req.FileID, actorUserID, accessToken, app)
|
||||
if err == nil || err == ente.ErrActiveLinkAlreadyExists {
|
||||
row, rowErr := c.FileLinkRepo.GetFileUrlRowByFileID(ctx, req.FileID)
|
||||
if rowErr != nil {
|
||||
return nil, stacktrace.Propagate(rowErr, "failed to get active file url token")
|
||||
}
|
||||
return c.mapRowToFileUrl(ctx, row), nil
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to create public file link")
|
||||
}
|
||||
|
||||
// Disable all public accessTokens generated for the given fileID till date.
|
||||
func (c *FileLinkController) Disable(ctx *gin.Context, fileID int64) error {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
file, err := c.FileRepo.GetFileAttributes(fileID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to get file attributes")
|
||||
}
|
||||
if userID != file.OwnerID {
|
||||
return stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
|
||||
}
|
||||
return c.FileLinkRepo.DisableLinkForFiles(ctx, []int64{fileID})
|
||||
}
|
||||
|
||||
func (c *FileLinkController) GetUrls(ctx *gin.Context, sinceTime int64, limit int64) ([]*ente.FileUrl, error) {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
app := auth.GetApp(ctx)
|
||||
fileLinks, err := c.FileLinkRepo.GetFileUrls(ctx, userID, sinceTime, limit, app)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get file urls")
|
||||
}
|
||||
var fileUrls []*ente.FileUrl
|
||||
for _, row := range fileLinks {
|
||||
fileUrls = append(fileUrls, c.mapRowToFileUrl(ctx, row))
|
||||
}
|
||||
return fileUrls, nil
|
||||
}
|
||||
|
||||
func (c *FileLinkController) UpdateSharedUrl(ctx *gin.Context, req ente.UpdateFileUrl) (*ente.FileUrl, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, stacktrace.Propagate(err, "invalid request")
|
||||
}
|
||||
fileLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, req.FileID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get file link info")
|
||||
}
|
||||
if fileLinkRow.OwnerID != auth.GetUserID(ctx.Request.Header) {
|
||||
return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
|
||||
}
|
||||
if req.ValidTill != nil {
|
||||
fileLinkRow.ValidTill = *req.ValidTill
|
||||
}
|
||||
if req.DeviceLimit != nil {
|
||||
fileLinkRow.DeviceLimit = *req.DeviceLimit
|
||||
}
|
||||
if req.PassHash != nil && req.Nonce != nil && req.OpsLimit != nil && req.MemLimit != nil {
|
||||
fileLinkRow.PassHash = req.PassHash
|
||||
fileLinkRow.Nonce = req.Nonce
|
||||
fileLinkRow.OpsLimit = req.OpsLimit
|
||||
fileLinkRow.MemLimit = req.MemLimit
|
||||
} else if req.DisablePassword != nil && *req.DisablePassword {
|
||||
fileLinkRow.PassHash = nil
|
||||
fileLinkRow.Nonce = nil
|
||||
fileLinkRow.OpsLimit = nil
|
||||
fileLinkRow.MemLimit = nil
|
||||
}
|
||||
if req.EnableDownload != nil {
|
||||
fileLinkRow.EnableDownload = *req.EnableDownload
|
||||
}
|
||||
|
||||
err = c.FileLinkRepo.UpdateLink(ctx, *fileLinkRow)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return c.mapRowToFileUrl(ctx, fileLinkRow), nil
|
||||
}
|
||||
|
||||
func (c *FileLinkController) Info(ctx *gin.Context) (*ente.File, error) {
|
||||
accessContext := auth.MustGetFileLinkAccessContext(ctx)
|
||||
return c.FileRepo.GetFileAttributes(accessContext.FileID)
|
||||
}
|
||||
|
||||
func (c *FileLinkController) PassInfo(ctx *gin.Context) (*ente.FileLinkRow, error) {
|
||||
accessContext := auth.MustGetFileLinkAccessContext(ctx)
|
||||
return c.FileLinkRepo.GetFileUrlRowByFileID(ctx, accessContext.FileID)
|
||||
}
|
||||
|
||||
// VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be
|
||||
// used by the client to pass in other requests for public collection.
|
||||
// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force
|
||||
// attack for guessing password.
|
||||
func (c *FileLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
|
||||
accessContext := auth.MustGetFileLinkAccessContext(ctx)
|
||||
collectionLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, accessContext.FileID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get public collection info")
|
||||
}
|
||||
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
|
||||
}
|
||||
|
||||
func (c *FileLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
|
||||
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
|
||||
}
|
||||
|
||||
func (c *FileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl {
|
||||
app := auth.GetApp(ctx)
|
||||
var url string
|
||||
if app == ente.Locker {
|
||||
url = c.FileLinkRepo.LockerFileLink(row.Token)
|
||||
} else {
|
||||
url = c.FileLinkRepo.PhotoLink(row.Token)
|
||||
}
|
||||
return &ente.FileUrl{
|
||||
LinkID: row.LinkID,
|
||||
FileID: row.FileID,
|
||||
URL: url,
|
||||
OwnerID: row.OwnerID,
|
||||
ValidTill: row.ValidTill,
|
||||
DeviceLimit: row.DeviceLimit,
|
||||
PasswordEnabled: row.PassHash != nil,
|
||||
Nonce: row.Nonce,
|
||||
OpsLimit: row.OpsLimit,
|
||||
MemLimit: row.MemLimit,
|
||||
EnableDownload: row.EnableDownload,
|
||||
CreatedAt: row.CreatedAt,
|
||||
}
|
||||
}
|
||||
54
server/pkg/controller/public/link_common.go
Normal file
54
server/pkg/controller/public/link_common.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/ente"
|
||||
enteJWT "github.com/ente-io/museum/ente/jwt"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
func validateJWTToken(secret []byte, jwtToken string, passwordHash string) error {
|
||||
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "JWT parsed failed")
|
||||
}
|
||||
claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim)
|
||||
|
||||
if !ok {
|
||||
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
|
||||
}
|
||||
if token.Valid && claims.PassHash == passwordHash {
|
||||
return nil
|
||||
}
|
||||
return ente.ErrInvalidPassword
|
||||
}
|
||||
|
||||
func verifyPassword(secret []byte, expectedPassHash *string, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
|
||||
if expectedPassHash == nil || *expectedPassHash == "" {
|
||||
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
|
||||
}
|
||||
if req.PassHash != *expectedPassHash {
|
||||
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{
|
||||
PassHash: req.PassHash,
|
||||
ExpiryTime: time.NDaysFromNow(365),
|
||||
})
|
||||
// Sign and get the complete encoded token as a string using the secret
|
||||
tokenString, err := token.SignedString(secret)
|
||||
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return &ente.VerifyPasswordResponse{
|
||||
JWTToken: tokenString,
|
||||
}, nil
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
public2 "github.com/ente-io/museum/pkg/controller/public"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"net/http"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
@@ -24,20 +26,20 @@ import (
|
||||
var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
|
||||
var whitelistedCollectionShareIDs = []int64{111}
|
||||
|
||||
// AccessTokenMiddleware intercepts and authenticates incoming requests
|
||||
type AccessTokenMiddleware struct {
|
||||
PublicCollectionRepo *repo.PublicCollectionRepository
|
||||
PublicCollectionCtrl *controller.PublicCollectionController
|
||||
// CollectionLinkMiddleware intercepts and authenticates incoming requests
|
||||
type CollectionLinkMiddleware struct {
|
||||
CollectionLinkRepo *public.CollectionLinkRepo
|
||||
PublicCollectionCtrl *public2.CollectionLinkController
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
Cache *cache.Cache
|
||||
BillingCtrl *controller.BillingController
|
||||
DiscordController *discord.DiscordController
|
||||
}
|
||||
|
||||
// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// within the header of a request and uses it to validate the access token and set the
|
||||
// ente.PublicAccessContext with auth.PublicAccessKey as key
|
||||
func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
accessToken := auth.GetAccessToken(c)
|
||||
if accessToken == "" {
|
||||
@@ -52,7 +54,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
|
||||
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
|
||||
cachedValue, cacheHit := m.Cache.Get(cacheKey)
|
||||
if !cacheHit {
|
||||
publicCollectionSummary, err = m.PublicCollectionRepo.GetCollectionSummaryByToken(c, accessToken)
|
||||
publicCollectionSummary, err = m.CollectionLinkRepo.GetCollectionSummaryByToken(c, accessToken)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
@@ -112,7 +114,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
userID, err := m.CollectionRepo.GetOwnerID(cID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
@@ -120,7 +122,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
|
||||
}
|
||||
|
||||
func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) {
|
||||
// skip deviceLimit check & record keeping for requests via CF worker
|
||||
if network.IsCFWorkerIP(ip) {
|
||||
@@ -128,7 +130,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
}
|
||||
|
||||
sharedID := collectionSummary.ID
|
||||
hasAccessedInPast, err := m.PublicCollectionRepo.AccessedInPast(ctx, sharedID, ip, ua)
|
||||
hasAccessedInPast, err := m.CollectionLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "")
|
||||
}
|
||||
@@ -136,17 +138,17 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
if hasAccessedInPast {
|
||||
return false, nil
|
||||
}
|
||||
count, err := m.PublicCollectionRepo.GetUniqueAccessCount(ctx, sharedID)
|
||||
count, err := m.CollectionLinkRepo.GetUniqueAccessCount(ctx, sharedID)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "failed to get unique access count")
|
||||
}
|
||||
|
||||
deviceLimit := int64(collectionSummary.DeviceLimit)
|
||||
if deviceLimit == controller.DeviceLimitThreshold {
|
||||
deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold
|
||||
if deviceLimit == public2.DeviceLimitThreshold {
|
||||
deviceLimit = public2.DeviceLimitThresholdMultiplier * public2.DeviceLimitThreshold
|
||||
}
|
||||
|
||||
if count >= controller.DeviceLimitWarningThreshold {
|
||||
if count >= public2.DeviceLimitWarningThreshold {
|
||||
if !array.Int64InList(sharedID, whitelistedCollectionShareIDs) {
|
||||
m.DiscordController.NotifyPotentialAbuse(
|
||||
fmt.Sprintf("Album exceeds warning threshold: {CollectionID: %d, ShareID: %d}",
|
||||
@@ -157,12 +159,12 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
if deviceLimit > 0 && count >= deviceLimit {
|
||||
return true, nil
|
||||
}
|
||||
err = m.PublicCollectionRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
|
||||
err = m.CollectionLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
|
||||
return false, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// validatePassword will verify if the user is provided correct password for the public album
|
||||
func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
|
||||
func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath string,
|
||||
collectionSummary ente.PublicCollectionSummary) error {
|
||||
if array.StringInList(reqPath, passwordWhiteListedURLs) {
|
||||
return nil
|
||||
168
server/pkg/middleware/file_link.go
Normal file
168
server/pkg/middleware/file_link.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
publicCtrl "github.com/ente-io/museum/pkg/controller/public"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"net/http"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/controller/discord"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/network"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var filePasswordWhiteListedURLs = []string{"/file-link/pass-info", "/file-link/verify-password"}
|
||||
|
||||
// FileLinkMiddleware intercepts and authenticates incoming requests
|
||||
type FileLinkMiddleware struct {
|
||||
FileLinkRepo *public.FileLinkRepository
|
||||
FileLinkCtrl *publicCtrl.FileLinkController
|
||||
Cache *cache.Cache
|
||||
BillingCtrl *controller.BillingController
|
||||
DiscordController *discord.DiscordController
|
||||
}
|
||||
|
||||
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
// within the header of a request and uses it to validate the access token and set the
|
||||
// ente.PublicAccessContext with auth.PublicAccessKey as key
|
||||
func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
accessToken := auth.GetAccessToken(c)
|
||||
if accessToken == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"})
|
||||
return
|
||||
}
|
||||
clientIP := network.GetClientIP(c)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
|
||||
cachedValue, cacheHit := m.Cache.Get(cacheKey)
|
||||
var fileLinkRow *ente.FileLinkRow
|
||||
var err error
|
||||
if !cacheHit {
|
||||
fileLinkRow, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Info("failed to get file link row by token")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
if fileLinkRow.IsDisabled {
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
|
||||
return
|
||||
}
|
||||
// validate if user still has active paid subscription
|
||||
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(fileLinkRow.OwnerID, true); err != nil {
|
||||
logrus.WithError(err).Info("failed to verify active paid subscription")
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
|
||||
return
|
||||
}
|
||||
|
||||
// validate device limit
|
||||
reached, limitErr := m.isDeviceLimitReached(c, fileLinkRow, clientIP, userAgent)
|
||||
if limitErr != nil {
|
||||
logrus.WithError(limitErr).Error("failed to check device limit")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
|
||||
return
|
||||
}
|
||||
if reached {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
fileLinkRow = cachedValue.(*ente.FileLinkRow)
|
||||
}
|
||||
|
||||
if fileLinkRow.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
|
||||
fileLinkRow.ValidTill < time.Microseconds() {
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
|
||||
return
|
||||
}
|
||||
|
||||
// checks password protected public collection
|
||||
if fileLinkRow.PassHash != nil && *fileLinkRow.PassHash != "" {
|
||||
reqPath := urlSanitizer(c)
|
||||
if err = m.validatePassword(c, reqPath, fileLinkRow); err != nil {
|
||||
logrus.WithError(err).Warn("password validation failed")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !cacheHit {
|
||||
m.Cache.Set(cacheKey, fileLinkRow, cache.DefaultExpiration)
|
||||
}
|
||||
|
||||
c.Set(auth.FileLinkAccessKey, &ente.FileLinkAccessContext{
|
||||
LinkID: fileLinkRow.LinkID,
|
||||
IP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
FileID: fileLinkRow.FileID,
|
||||
OwnerID: fileLinkRow.OwnerID,
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) {
|
||||
// skip deviceLimit check & record keeping for requests via CF worker
|
||||
if network.IsCFWorkerIP(ip) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
sharedID := collectionSummary.LinkID
|
||||
hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "")
|
||||
}
|
||||
// if the device has accessed the url in the past, let it access it now as well, irrespective of device limit.
|
||||
if hasAccessedInPast {
|
||||
return false, nil
|
||||
}
|
||||
count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID)
|
||||
if err != nil {
|
||||
return false, stacktrace.Propagate(err, "failed to get unique access count")
|
||||
}
|
||||
|
||||
deviceLimit := int64(collectionSummary.DeviceLimit)
|
||||
if deviceLimit == publicCtrl.DeviceLimitThreshold {
|
||||
deviceLimit = publicCtrl.DeviceLimitThresholdMultiplier * publicCtrl.DeviceLimitThreshold
|
||||
}
|
||||
|
||||
if count >= publicCtrl.DeviceLimitWarningThreshold {
|
||||
m.DiscordController.NotifyPotentialAbuse(
|
||||
fmt.Sprintf("FileLink exceeds warning threshold: {FileID: %d, ShareID: %s}",
|
||||
collectionSummary.FileID, collectionSummary.LinkID))
|
||||
}
|
||||
|
||||
if deviceLimit > 0 && count >= deviceLimit {
|
||||
return true, nil
|
||||
}
|
||||
err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
|
||||
return false, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// validatePassword will verify if the user is provided correct password for the public album
|
||||
func (m *FileLinkMiddleware) validatePassword(
|
||||
c *gin.Context,
|
||||
reqPath string,
|
||||
fileLinkRow *ente.FileLinkRow,
|
||||
) error {
|
||||
accessTokenJWT := auth.GetAccessTokenJWT(c)
|
||||
if accessTokenJWT == "" {
|
||||
if array.StringInList(reqPath, filePasswordWhiteListedURLs) {
|
||||
return nil
|
||||
}
|
||||
return &ente.ErrPassProtectedResource
|
||||
}
|
||||
return m.FileLinkCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
|
||||
}
|
||||
@@ -140,6 +140,7 @@ func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limi
|
||||
reqPath == "/users/verify-email" ||
|
||||
reqPath == "/user/change-email" ||
|
||||
reqPath == "/public-collection/verify-password" ||
|
||||
reqPath == "/file-link/verify-password" ||
|
||||
reqPath == "/family/accept-invite" ||
|
||||
reqPath == "/users/srp/attributes" ||
|
||||
(reqPath == "/cast/device-info" && reqMethod == "POST") ||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"strconv"
|
||||
t "time"
|
||||
|
||||
@@ -22,13 +23,13 @@ import (
|
||||
// CollectionRepository defines the methods for inserting, updating and
|
||||
// retrieving collection entities from the underlying repository
|
||||
type CollectionRepository struct {
|
||||
DB *sql.DB
|
||||
FileRepo *FileRepository
|
||||
PublicCollectionRepo *PublicCollectionRepository
|
||||
TrashRepo *TrashRepository
|
||||
SecretEncryptionKey []byte
|
||||
QueueRepo *QueueRepository
|
||||
LatencyLogger *prometheus.HistogramVec
|
||||
DB *sql.DB
|
||||
FileRepo *FileRepository
|
||||
CollectionLinkRepo *public.CollectionLinkRepo
|
||||
TrashRepo *TrashRepository
|
||||
SecretEncryptionKey []byte
|
||||
QueueRepo *QueueRepository
|
||||
LatencyLogger *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
type SharedCollection struct {
|
||||
@@ -74,7 +75,7 @@ func (repo *CollectionRepository) Get(collectionID int64) (ente.Collection, erro
|
||||
c.EncryptedName = encryptedName.String
|
||||
c.NameDecryptionNonce = nameDecryptionNonce.String
|
||||
}
|
||||
urlMap, err := repo.PublicCollectionRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID})
|
||||
urlMap, err := repo.CollectionLinkRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID})
|
||||
if err != nil {
|
||||
return ente.Collection{}, stacktrace.Propagate(err, "failed to get publicURL info")
|
||||
}
|
||||
@@ -174,7 +175,7 @@ pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_
|
||||
if _, ok := addPublicUrlMap[pctToken.String]; !ok {
|
||||
addPublicUrlMap[pctToken.String] = true
|
||||
url := ente.PublicURL{
|
||||
URL: repo.PublicCollectionRepo.GetAlbumUrl(pctToken.String),
|
||||
URL: repo.CollectionLinkRepo.GetAlbumUrl(pctToken.String),
|
||||
DeviceLimit: int(pctDeviceLimit.Int32),
|
||||
ValidTill: pctValidTill.Int64,
|
||||
EnableDownload: pctEnableDownload.Bool,
|
||||
|
||||
@@ -638,6 +638,16 @@ func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.Fi
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (repo *FileRepository) GetFileAttributes(fileID int64) (*ente.File, error) {
|
||||
rows := repo.DB.QueryRow(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = $1`, fileID)
|
||||
var file ente.File
|
||||
err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
// GetUsage gets the Storage usage of a user
|
||||
// Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage
|
||||
func (repo *FileRepository) GetUsage(userID int64) (int64, error) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package repo
|
||||
package public
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,29 +13,29 @@ import (
|
||||
|
||||
const BaseShareURL = "https://albums.ente.io/?t=%s"
|
||||
|
||||
// PublicCollectionRepository defines the methods for inserting, updating and
|
||||
// CollectionLinkRepo defines the methods for inserting, updating and
|
||||
// retrieving entities related to public collections
|
||||
type PublicCollectionRepository struct {
|
||||
type CollectionLinkRepo struct {
|
||||
DB *sql.DB
|
||||
albumHost string
|
||||
}
|
||||
|
||||
// NewPublicCollectionRepository ..
|
||||
func NewPublicCollectionRepository(db *sql.DB, albumHost string) *PublicCollectionRepository {
|
||||
// NewCollectionLinkRepository ..
|
||||
func NewCollectionLinkRepository(db *sql.DB, albumHost string) *CollectionLinkRepo {
|
||||
if albumHost == "" {
|
||||
albumHost = "https://albums.ente.io"
|
||||
}
|
||||
return &PublicCollectionRepository{
|
||||
return &CollectionLinkRepo{
|
||||
DB: db,
|
||||
albumHost: albumHost,
|
||||
}
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) GetAlbumUrl(token string) string {
|
||||
func (pcr *CollectionLinkRepo) GetAlbumUrl(token string) string {
|
||||
return fmt.Sprintf("%s/?t=%s", pcr.albumHost, token)
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) Insert(ctx context.Context,
|
||||
func (pcr *CollectionLinkRepo) Insert(ctx context.Context,
|
||||
cID int64, token string, validTill int64, deviceLimit int, enableCollect bool, enableJoin *bool) error {
|
||||
// default value for enableJoin is true
|
||||
join := true
|
||||
@@ -51,7 +51,7 @@ func (pcr *PublicCollectionRepository) Insert(ctx context.Context,
|
||||
return stacktrace.Propagate(err, "failed to insert")
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID int64) error {
|
||||
func (pcr *CollectionLinkRepo) DisableSharing(ctx context.Context, cID int64) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET is_disabled = true where
|
||||
collection_id = $1 and is_disabled = false`, cID)
|
||||
return stacktrace.Propagate(err, "failed to disable sharing")
|
||||
@@ -59,7 +59,7 @@ func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID i
|
||||
|
||||
// GetCollectionToActivePublicURLMap will return map of collectionID to PublicURLs which are not disabled yet.
|
||||
// Note: The url could be expired or deviceLimit is already reached
|
||||
func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) {
|
||||
func (pcr *CollectionLinkRepo) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) {
|
||||
rows, err := pcr.DB.QueryContext(ctx, `SELECT collection_id, access_token, valid_till, device_limit, enable_download, enable_collect, enable_join, pw_nonce, mem_limit, ops_limit FROM
|
||||
public_collection_tokens WHERE collection_id = ANY($1) and is_disabled = FALSE`,
|
||||
pq.Array(collectionIDs))
|
||||
@@ -92,26 +92,26 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetActivePublicCollectionToken will return ente.PublicCollectionToken for given collection ID
|
||||
// GetActiveCollectionLinkRow will return ente.CollectionLinkRow for given collection ID
|
||||
// Note: The token could be expired or deviceLimit is already reached
|
||||
func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) {
|
||||
func (pcr *CollectionLinkRepo) GetActiveCollectionLinkRow(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit,
|
||||
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM
|
||||
public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`,
|
||||
collectionID)
|
||||
|
||||
//defer rows.Close()
|
||||
ret := ente.PublicCollectionToken{}
|
||||
ret := ente.CollectionLinkRow{}
|
||||
err := row.Scan(&ret.ID, &ret.CollectionID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
|
||||
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload, &ret.EnableCollect, &ret.EnableJoin)
|
||||
if err != nil {
|
||||
return ente.PublicCollectionToken{}, stacktrace.Propagate(err, "")
|
||||
return ente.CollectionLinkRow{}, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// UpdatePublicCollectionToken will update the row for corresponding public collection token
|
||||
func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.Context, pct ente.PublicCollectionToken) error {
|
||||
func (pcr *CollectionLinkRepo) UpdatePublicCollectionToken(ctx context.Context, pct ente.CollectionLinkRow) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET valid_till = $1, device_limit = $2,
|
||||
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7, enable_collect = $8, enable_join = $9
|
||||
where id = $10`,
|
||||
@@ -119,7 +119,7 @@ func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.C
|
||||
return stacktrace.Propagate(err, "failed to update public collection token")
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext,
|
||||
func (pcr *CollectionLinkRepo) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext,
|
||||
url string, reason string, details ente.AbuseReportDetails) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_abuse_report
|
||||
(share_id, ip, user_agent, url, reason, details) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
@@ -128,7 +128,7 @@ func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, ac
|
||||
return stacktrace.Propagate(err, "failed to record abuse report")
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) {
|
||||
func (pcr *CollectionLinkRepo) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_abuse_report WHERE share_id = $1`, accessCtx.ID)
|
||||
var count int64 = 0
|
||||
err := row.Scan(&count)
|
||||
@@ -138,7 +138,7 @@ func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context,
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) {
|
||||
func (pcr *CollectionLinkRepo) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_collection_access_history WHERE share_id = $1`, shareId)
|
||||
var count int64 = 0
|
||||
err := row.Scan(&count)
|
||||
@@ -148,7 +148,7 @@ func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context,
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
|
||||
func (pcr *CollectionLinkRepo) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_collection_access_history
|
||||
(share_id, ip, user_agent) VALUES ($1, $2, $3)
|
||||
ON CONFLICT ON CONSTRAINT unique_access_sid_ip_ua DO NOTHING;`,
|
||||
@@ -157,7 +157,7 @@ func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context,
|
||||
}
|
||||
|
||||
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
|
||||
func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
|
||||
func (pcr *CollectionLinkRepo) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `select share_id from public_collection_access_history where share_id =$1 and ip = $2 and user_agent = $3`,
|
||||
shareID, ip, ua)
|
||||
var tempID int64
|
||||
@@ -168,7 +168,7 @@ func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, share
|
||||
return true, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) {
|
||||
func (pcr *CollectionLinkRepo) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx,
|
||||
`SELECT sct.id, sct.collection_id, sct.is_disabled, sct.valid_till, sct.device_limit, sct.pw_hash,
|
||||
sct.created_at, sct.updated_at, count(ah.share_id)
|
||||
@@ -185,7 +185,7 @@ func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.C
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) {
|
||||
func (pcr *CollectionLinkRepo) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) {
|
||||
rows, err := pcr.DB.QueryContext(ctx, `select pt.collection_id from public_collection_tokens pt left join collections c on pt.collection_id = c.collection_id where pt.is_disabled = FALSE and c.owner_id= $1;`, userID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
@@ -204,7 +204,7 @@ func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.C
|
||||
}
|
||||
|
||||
// CleanupAccessHistory public_collection_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days
|
||||
func (pcr *PublicCollectionRepository) CleanupAccessHistory(ctx context.Context) error {
|
||||
func (pcr *CollectionLinkRepo) CleanupAccessHistory(ctx context.Context) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_collection_access_history WHERE share_id IN (SELECT id FROM public_collection_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to clean up public collection access history")
|
||||
218
server/pkg/repo/public/file_link.go
Normal file
218
server/pkg/repo/public/file_link.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/ente/base"
|
||||
"github.com/lib/pq"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/stacktrace"
|
||||
)
|
||||
|
||||
// FileLinkRepository defines the methods for inserting, updating and
|
||||
// retrieving entities related to public file
|
||||
type FileLinkRepository struct {
|
||||
DB *sql.DB
|
||||
photoHost string
|
||||
lockerHost string
|
||||
}
|
||||
|
||||
// NewFileLinkRepo ..
|
||||
func NewFileLinkRepo(db *sql.DB) *FileLinkRepository {
|
||||
albumHost := viper.GetString("apps.public-albums")
|
||||
if albumHost == "" {
|
||||
albumHost = "https://albums.ente.io"
|
||||
}
|
||||
lockerHost := viper.GetString("apps.public-locker")
|
||||
if lockerHost == "" {
|
||||
lockerHost = "https://locker.ente.io"
|
||||
}
|
||||
return &FileLinkRepository{
|
||||
DB: db,
|
||||
photoHost: albumHost,
|
||||
lockerHost: lockerHost,
|
||||
}
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) PhotoLink(token string) string {
|
||||
return fmt.Sprintf("%s/?t=%s", pcr.photoHost, token)
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) LockerFileLink(token string) string {
|
||||
return fmt.Sprintf("%s/?t=%s", pcr.lockerHost, token)
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) Insert(
|
||||
ctx context.Context,
|
||||
fileID int64,
|
||||
ownerID int64,
|
||||
token string,
|
||||
app ente.App,
|
||||
) (*string, error) {
|
||||
id, err := base.NewID("pft")
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to generate new ID for public file token")
|
||||
}
|
||||
_, err = pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens
|
||||
(id, file_id, owner_id, access_token, app) VALUES ($1, $2, $3, $4, $5)`,
|
||||
id, fileID, ownerID, token, string(app))
|
||||
if err != nil {
|
||||
if err.Error() == "pq: duplicate key value violates unique constraint \"public_active_file_link_unique_idx\"" {
|
||||
return nil, ente.ErrActiveLinkAlreadyExists
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to insert")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetActiveFileUrlToken will return ente.CollectionLinkRow for given collection ID
|
||||
// Note: The token could be expired or deviceLimit is already reached
|
||||
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit,
|
||||
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM
|
||||
public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`,
|
||||
fileID)
|
||||
|
||||
ret := ente.FileLinkRow{}
|
||||
err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
|
||||
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
func (pcr *FileLinkRepository) GetFileUrls(ctx context.Context, userID int64, sinceTime int64, limit int64, app ente.App) ([]*ente.FileLinkRow, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
}
|
||||
query := `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit,
|
||||
created_at, updated_at FROM public_file_tokens
|
||||
WHERE owner_id = $1 AND created_at > $2 AND app = $3 ORDER BY updated_at DESC LIMIT $4`
|
||||
rows, err := pcr.DB.QueryContext(ctx, query, userID, sinceTime, string(app), limit)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to get public file urls")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []*ente.FileLinkRow
|
||||
for rows.Next() {
|
||||
var row ente.FileLinkRow
|
||||
err = rows.Scan(&row.LinkID, &row.FileID, &row.OwnerID, &row.IsDisabled,
|
||||
&row.ValidTill, &row.DeviceLimit, &row.EnableDownload,
|
||||
&row.PassHash, &row.Nonce, &row.MemLimit,
|
||||
&row.OpsLimit, &row.CreatedAt, &row.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to scan public file url row")
|
||||
}
|
||||
result = append(result, &row)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) DisableLinkForFiles(ctx context.Context, fileIDs []int64) error {
|
||||
if len(fileIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
query := `UPDATE public_file_tokens SET is_disabled = TRUE WHERE file_id = ANY($1)`
|
||||
_, err := pcr.DB.ExecContext(ctx, query, pq.Array(fileIDs))
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to disable public file links")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableLinksForUser will disable all public file links for the given user
|
||||
func (pcr *FileLinkRepository) DisableLinksForUser(ctx context.Context, userID int64) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET is_disabled = TRUE WHERE owner_id = $1`, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to disable public file link")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx,
|
||||
`SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit
|
||||
created_at, updated_at
|
||||
from public_file_tokens
|
||||
where access_token = $1
|
||||
`, accessToken)
|
||||
var result = ente.FileLinkRow{}
|
||||
err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ente.ErrNotFound
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to get public file url summary by token")
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) GetFileUrlRowByFileID(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx,
|
||||
`SELECT id, file_id, access_token, owner_id, is_disabled, enable_download, valid_till, device_limit, pw_hash, pw_nonce, mem_limit, ops_limit,
|
||||
created_at, updated_at
|
||||
from public_file_tokens
|
||||
where file_id = $1 and is_disabled = FALSE`, fileID)
|
||||
var result = ente.FileLinkRow{}
|
||||
err := row.Scan(&result.LinkID, &result.FileID, &result.Token, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ente.ErrNotFound
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to get public file url summary by file ID")
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// UpdateLink will update the row for corresponding public file token
|
||||
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2,
|
||||
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7
|
||||
where id = $8`,
|
||||
pct.ValidTill, pct.DeviceLimit, pct.PassHash, pct.Nonce, pct.MemLimit, pct.OpsLimit, pct.EnableDownload, pct.LinkID)
|
||||
return stacktrace.Propagate(err, "failed to update public file token")
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId string) (int64, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_file_tokens_access_history WHERE id = $1`, linkId)
|
||||
var count int64 = 0
|
||||
err := row.Scan(&count)
|
||||
if err != nil {
|
||||
return -1, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history
|
||||
(id, ip, user_agent) VALUES ($1, $2, $3)
|
||||
ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`,
|
||||
shareID, ip, ua)
|
||||
return stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
|
||||
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) {
|
||||
row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`,
|
||||
shareID, ip, ua)
|
||||
var tempID int64
|
||||
err := row.Scan(&tempID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return true, stacktrace.Propagate(err, "failed to record access history")
|
||||
}
|
||||
|
||||
// CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days
|
||||
func (pcr *FileLinkRepository) CleanupAccessHistory(ctx context.Context) error {
|
||||
_, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_file_tokens_access_history WHERE id IN (SELECT id FROM public_file_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to clean up public file access history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
"strings"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
@@ -32,10 +33,11 @@ type FileWithUpdatedAt struct {
|
||||
}
|
||||
|
||||
type TrashRepository struct {
|
||||
DB *sql.DB
|
||||
ObjectRepo *ObjectRepository
|
||||
FileRepo *FileRepository
|
||||
QueueRepo *QueueRepository
|
||||
DB *sql.DB
|
||||
ObjectRepo *ObjectRepository
|
||||
FileRepo *FileRepository
|
||||
QueueRepo *QueueRepository
|
||||
FileLinkRepo *public.FileLinkRepository
|
||||
}
|
||||
|
||||
func (t *TrashRepository) InsertItems(ctx context.Context, tx *sql.Tx, userID int64, items []ente.TrashItemRequest) error {
|
||||
@@ -156,6 +158,13 @@ func (t *TrashRepository) TrashFiles(fileIDs []int64, userID int64, trash ente.T
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
err = tx.Commit()
|
||||
|
||||
if err == nil {
|
||||
removeLinkErr := t.FileLinkRepo.DisableLinkForFiles(ctx, fileIDs)
|
||||
if removeLinkErr != nil {
|
||||
return stacktrace.Propagate(removeLinkErr, "failed to disable file links for files being trashed")
|
||||
}
|
||||
}
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicAccessKey = "X-Public-Access-ID"
|
||||
CastContext = "X-Cast-Context"
|
||||
PublicAccessKey = "X-Public-Access-ID"
|
||||
FileLinkAccessKey = "X-Public-FileLink-Access-ID"
|
||||
CastContext = "X-Cast-Context"
|
||||
)
|
||||
|
||||
// GenerateRandomBytes returns securely generated random bytes.
|
||||
@@ -120,6 +121,8 @@ func GetCastToken(c *gin.Context) string {
|
||||
return token
|
||||
}
|
||||
|
||||
// GetAccessTokenJWT fetches the JWT access token from the request header or query parameters.
|
||||
// This token is issued by server on password verification of links that are protected by password.
|
||||
func GetAccessTokenJWT(c *gin.Context) string {
|
||||
token := c.GetHeader("X-Auth-Access-Token-JWT")
|
||||
if token == "" {
|
||||
@@ -132,6 +135,10 @@ func MustGetPublicAccessContext(c *gin.Context) ente.PublicAccessContext {
|
||||
return c.MustGet(PublicAccessKey).(ente.PublicAccessContext)
|
||||
}
|
||||
|
||||
func MustGetFileLinkAccessContext(c *gin.Context) *ente.FileLinkAccessContext {
|
||||
return c.MustGet(FileLinkAccessKey).(*ente.FileLinkAccessContext)
|
||||
}
|
||||
|
||||
func GetCastCtx(c *gin.Context) cast.AuthContext {
|
||||
return c.MustGet(CastContext).(cast.AuthContext)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user