[Server] Single file link (#6566)

## Description
Adds 4 authenticate API for
- Creating link for individual file
- Update Link
- Delete Link
- Fetch all links (based on header, the server will return particular
app's link)

For link preview
- API to get Info (pending discussion)
- API to get file attributes (pending discussion)
- APIs to get thumbnail and file
- API to verify password


Pending
- [x] Clean up on account deletion
- [x] Clean up on file deletion
- [x] Clean up history for disabled links

## Tests

Basic santiy check during client integration
This commit is contained in:
Neeraj
2025-08-04 14:41:50 +05:30
committed by GitHub
26 changed files with 1148 additions and 215 deletions

View File

@@ -5,6 +5,9 @@ import (
"database/sql"
b64 "encoding/base64"
"fmt"
"github.com/ente-io/museum/pkg/controller/collections"
publicCtrl "github.com/ente-io/museum/pkg/controller/public"
"github.com/ente-io/museum/pkg/repo/public"
"net/http"
"os"
"os/signal"
@@ -14,8 +17,6 @@ import (
"syscall"
"time"
"github.com/ente-io/museum/pkg/controller/collections"
"github.com/ente-io/museum/ente/base"
"github.com/ente-io/museum/pkg/controller/emergency"
"github.com/ente-io/museum/pkg/controller/file_copy"
@@ -97,6 +98,7 @@ func main() {
}
viper.SetDefault("apps.public-albums", "https://albums.ente.io")
viper.SetDefault("apps.public-locker", "https://locker.ente.io")
viper.SetDefault("apps.accounts", "https://accounts.ente.io")
viper.SetDefault("apps.cast", "https://cast.ente.io")
viper.SetDefault("apps.family", "https://family.ente.io")
@@ -174,11 +176,13 @@ func main() {
fileRepo := &repo.FileRepository{DB: db, S3Config: s3Config, QueueRepo: queueRepo,
ObjectRepo: objectRepo, ObjectCleanupRepo: objectCleanupRepo,
ObjectCopiesRepo: objectCopiesRepo, UsageRepo: usageRepo}
fileLinkRepo := public.NewFileLinkRepo(db)
fileDataRepo := &fileDataRepo.Repository{DB: db, ObjectCleanupRepo: objectCleanupRepo}
familyRepo := &repo.FamilyRepository{DB: db}
trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo}
publicCollectionRepo := repo.NewPublicCollectionRepository(db, viper.GetString("apps.public-albums"))
collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, PublicCollectionRepo: publicCollectionRepo,
trashRepo := &repo.TrashRepository{DB: db, ObjectRepo: objectRepo, FileRepo: fileRepo, QueueRepo: queueRepo, FileLinkRepo: fileLinkRepo}
collectionLinkRepo := public.NewCollectionLinkRepository(db, viper.GetString("apps.public-albums"))
collectionRepo := &repo.CollectionRepository{DB: db, FileRepo: fileRepo, CollectionLinkRepo: collectionLinkRepo,
TrashRepo: trashRepo, SecretEncryptionKey: secretEncryptionKeyBytes, QueueRepo: queueRepo, LatencyLogger: latencyLogger}
pushRepo := &repo.PushTokenRepository{DB: db}
kexRepo := &kex.Repository{
@@ -300,26 +304,27 @@ func main() {
UsageRepo: usageRepo,
}
publicCollectionCtrl := &controller.PublicCollectionController{
collectionLinkCtrl := &publicCtrl.CollectionLinkController{
FileController: fileController,
EmailNotificationCtrl: emailNotificationCtrl,
PublicCollectionRepo: publicCollectionRepo,
CollectionLinkRepo: collectionLinkRepo,
FileLinkRepo: fileLinkRepo,
CollectionRepo: collectionRepo,
UserRepo: userRepo,
JwtSecret: jwtSecretBytes,
}
collectionController := &collections.CollectionController{
CollectionRepo: collectionRepo,
EmailCtrl: emailNotificationCtrl,
AccessCtrl: accessCtrl,
PublicCollectionCtrl: publicCollectionCtrl,
UserRepo: userRepo,
FileRepo: fileRepo,
CastRepo: &castDb,
BillingCtrl: billingController,
QueueRepo: queueRepo,
TaskRepo: taskLockingRepo,
CollectionRepo: collectionRepo,
EmailCtrl: emailNotificationCtrl,
AccessCtrl: accessCtrl,
CollectionLinkCtrl: collectionLinkCtrl,
UserRepo: userRepo,
FileRepo: fileRepo,
CastRepo: &castDb,
BillingCtrl: billingController,
QueueRepo: queueRepo,
TaskRepo: taskLockingRepo,
}
kexCtrl := &kexCtrl.Controller{
@@ -351,6 +356,12 @@ func main() {
userCache,
userCacheCtrl,
)
fileLinkCtrl := &publicCtrl.FileLinkController{
FileController: fileController,
FileLinkRepo: fileLinkRepo,
FileRepo: fileRepo,
JwtSecret: jwtSecretBytes,
}
passkeyCtrl := &controller.PasskeyController{
Repo: passkeysRepo,
@@ -358,14 +369,21 @@ func main() {
}
authMiddleware := middleware.AuthMiddleware{UserAuthRepo: userAuthRepo, Cache: authCache, UserController: userController}
accessTokenMiddleware := middleware.AccessTokenMiddleware{
PublicCollectionRepo: publicCollectionRepo,
PublicCollectionCtrl: publicCollectionCtrl,
collectionLinkMiddleware := middleware.CollectionLinkMiddleware{
CollectionLinkRepo: collectionLinkRepo,
PublicCollectionCtrl: collectionLinkCtrl,
CollectionRepo: collectionRepo,
Cache: accessTokenCache,
BillingCtrl: billingController,
DiscordController: discordController,
}
fileLinkMiddleware := &middleware.FileLinkMiddleware{
FileLinkRepo: fileLinkRepo,
FileLinkCtrl: fileLinkCtrl,
Cache: accessTokenCache,
BillingCtrl: billingController,
DiscordController: discordController,
}
if environment != "local" {
gin.SetMode(gin.ReleaseMode)
@@ -404,7 +422,9 @@ func main() {
familiesJwtAuthAPI.Use(rateLimiter.GlobalRateLimiter(), authMiddleware.TokenAuthMiddleware(jwt.FAMILIES.Ptr()), rateLimiter.APIRateLimitForUserMiddleware(urlSanitizer))
publicCollectionAPI := server.Group("/public-collection")
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), accessTokenMiddleware.AccessTokenAuthMiddleware(urlSanitizer))
publicCollectionAPI.Use(rateLimiter.GlobalRateLimiter(), collectionLinkMiddleware.Authenticate(urlSanitizer))
fileLinkApi := server.GET("/file-link")
fileLinkApi.Use(rateLimiter.GlobalRateLimiter(), fileLinkMiddleware.Authenticate(urlSanitizer))
healthCheckHandler := &api.HealthCheckHandler{
DB: db,
@@ -432,6 +452,7 @@ func main() {
Controller: fileController,
FileCopyCtrl: fileCopyCtrl,
FileDataCtrl: fileDataCtrl,
FileUrlCtrl: fileLinkCtrl,
}
privateAPI.GET("/files/upload-urls", fileHandler.GetUploadURLs)
privateAPI.GET("/files/multipart-upload-urls", fileHandler.GetMultipartUploadURLs)
@@ -440,6 +461,11 @@ func main() {
privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail)
privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail)
privateAPI.POST("/files/share-url", fileHandler.ShareUrl)
privateAPI.PUT("/files/share-url", fileHandler.UpdateFileURL)
privateAPI.DELETE("/files/share-url/:fileID", fileHandler.DisableUrl)
privateAPI.GET("/files/share-urls/", fileHandler.GetUrls)
privateAPI.PUT("/files/data", fileHandler.PutFileData)
privateAPI.PUT("/files/video-data", fileHandler.PutVideoData)
privateAPI.POST("/files/data/status-diff", fileHandler.FileDataStatusDiff)
@@ -566,13 +592,19 @@ func main() {
privateAPI.PUT("/collections/sharee-magic-metadata", collectionHandler.ShareeMagicMetadataUpdate)
publicCollectionHandler := &api.PublicCollectionHandler{
Controller: publicCollectionCtrl,
Controller: collectionLinkCtrl,
FileCtrl: fileController,
CollectionCtrl: collectionController,
FileDataCtrl: fileDataCtrl,
StorageBonusController: storageBonusCtrl,
}
fileLinkApi.GET("/info", fileHandler.LinkInfo)
fileLinkApi.GET("/pass-info", fileHandler.PasswordInfo)
fileLinkApi.GET("/thumbnail", fileHandler.LinkThumbnail)
fileLinkApi.GET("/file", fileHandler.LinkFile)
fileLinkApi.POST("/verify-password", fileHandler.VerifyPassword)
publicCollectionAPI.GET("/files/preview/:fileID", publicCollectionHandler.GetThumbnail)
publicCollectionAPI.GET("/files/download/:fileID", publicCollectionHandler.GetFile)
publicCollectionAPI.GET("/files/data/fetch", publicCollectionHandler.GetFileData)
@@ -770,7 +802,7 @@ func main() {
setKnownAPIs(server.Routes())
setupAndStartBackgroundJobs(objectCleanupController, replicationController3, fileDataCtrl)
setupAndStartCrons(
userAuthRepo, publicCollectionRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl,
userAuthRepo, collectionLinkRepo, fileLinkRepo, twoFactorRepo, passkeysRepo, fileController, taskLockingRepo, emailNotificationCtrl,
trashController, pushController, objectController, dataCleanupController, storageBonusCtrl, emergencyCtrl,
embeddingController, healthCheckHandler, kexCtrl, castDb)
@@ -899,7 +931,8 @@ func setupAndStartBackgroundJobs(
objectCleanupController.StartClearingOrphanObjects()
}
func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionRepo *repo.PublicCollectionRepository,
func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, collectionLinkRepo *public.CollectionLinkRepo,
fileLinkRepo *public.FileLinkRepository,
twoFactorRepo *repo.TwoFactorRepository, passkeysRepo *passkey.Repository, fileController *controller.FileController,
taskRepo *repo.TaskLockRepository, emailNotificationCtrl *email.EmailNotificationController,
trashController *controller.TrashController, pushController *controller.PushController,
@@ -925,7 +958,8 @@ func setupAndStartCrons(userAuthRepo *repo.UserAuthRepository, publicCollectionR
schedule(c, "@every 24h", func() {
_ = userAuthRepo.RemoveDeletedTokens(timeUtil.MicrosecondsBeforeDays(30))
_ = castDb.DeleteOldSessions(context.Background(), timeUtil.MicrosecondsBeforeDays(7))
_ = publicCollectionRepo.CleanupAccessHistory(context.Background())
_ = collectionLinkRepo.CleanupAccessHistory(context.Background())
_ = fileLinkRepo.CleanupAccessHistory(context.Background())
})
schedule(c, "@every 1m", func() {

View File

@@ -79,9 +79,14 @@ http:
apps:
# Default is https://albums.ente.io
#
# If you're running a self hosted instance and wish to serve public links,
# If you're running a self hosted instance and wish to serve public links for photos,
# set this to the URL where your albums web app is running.
public-albums:
# Default is https://locker.ente.io
#
# If you're running a self-hosted instance and wish to serve public links for locker,
# set this to the URL where your albums web app is running.
public-locker:
# Default is https://cast.ente.io
cast:
# Default is https://accounts.ente.io

View File

@@ -97,8 +97,8 @@ var ErrUserDeleted = errors.New("user account has been deleted")
// ErrLockUnavailable is thrown when a lock could not be acquired
var ErrLockUnavailable = errors.New("could not acquire lock")
// ErrActiveLinkAlreadyExists is thrown when the collection already has active public link
var ErrActiveLinkAlreadyExists = errors.New("Collection already has active public link")
// ErrActiveLinkAlreadyExists is thrown when an active link already exists for entity
var ErrActiveLinkAlreadyExists = errors.New("link already exists for this entity")
// ErrNotImplemented indicates that the action that we tried to perform is not
// available at this museum instance. e.g. this could be something that is not
@@ -176,6 +176,11 @@ var ErrMaxPasskeysReached = ApiError{
Message: "Max passkeys limit reached",
HttpStatusCode: http.StatusConflict,
}
var ErrPassProtectedResource = ApiError{
Code: "PASS_PROTECTED_RESOURCE",
Message: "This resource is password protected",
HttpStatusCode: http.StatusForbidden,
}
var ErrCastPermissionDenied = ApiError{
Code: "CAST_PERMISSION_DENIED",

94
server/ente/file_link.go Normal file
View File

@@ -0,0 +1,94 @@
package ente
import (
"fmt"
"github.com/ente-io/museum/pkg/utils/time"
)
// CreateFileUrl represents an encrypted file in the system
type CreateFileUrl struct {
FileID int64 `json:"fileID" binding:"required"`
App App `json:"app" binding:"required"`
}
// UpdateFileUrl ..
type UpdateFileUrl struct {
LinkID string `json:"linkID" binding:"required"`
FileID int64 `json:"fileID" binding:"required"`
ValidTill *int64 `json:"validTill"`
DeviceLimit *int `json:"deviceLimit"`
PassHash *string
Nonce *string
MemLimit *int64
OpsLimit *int64
EnableDownload *bool `json:"enableDownload"`
DisablePassword *bool `json:"disablePassword"`
}
func (ut *UpdateFileUrl) Validate() error {
if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil &&
ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil {
return NewBadRequestWithMessage("all parameters are missing")
}
if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) {
return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit))
}
if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() {
return NewBadRequestWithMessage("valid till should be greater than current timestamp")
}
var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil
var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil
if !(allPassParamsMissing || allPassParamsPresent) {
return NewBadRequestWithMessage("all password params should be either present or missing")
}
if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword {
return NewBadRequestWithMessage("can not set and disable password in same request")
}
return nil
}
type FileLinkRow struct {
LinkID string
OwnerID int64
FileID int64
Token string
DeviceLimit int
ValidTill int64
IsDisabled bool
PassHash *string
Nonce *string
MemLimit *int64
OpsLimit *int64
EnableDownload bool
CreatedAt int64
UpdatedAt int64
}
type FileUrl struct {
LinkID string `json:"linkID" binding:"required"`
URL string `json:"url" binding:"required"`
OwnerID int64 `json:"ownerID" binding:"required"`
FileID int64 `json:"fileID" binding:"required"`
ValidTill int64 `json:"validTill"`
DeviceLimit int `json:"deviceLimit"`
PasswordEnabled bool `json:"passwordEnabled"`
// Nonce contains the nonce value for the password if the link is password protected.
Nonce *string `json:"nonce,omitempty"`
MemLimit *int64 `json:"memLimit,omitempty"`
OpsLimit *int64 `json:"opsLimit,omitempty"`
EnableDownload bool `json:"enableDownload"`
CreatedAt int64 `json:"createdAt"`
}
type FileLinkAccessContext struct {
LinkID string
IP string
UserAgent string
FileID int64
OwnerID int64
}

View File

@@ -40,13 +40,13 @@ func (w WebCommonJWTClaim) Valid() error {
return nil
}
// PublicAlbumPasswordClaim refer to token granted post public album password verification
type PublicAlbumPasswordClaim struct {
// LinkPasswordClaim refer to token granted post link password verification
type LinkPasswordClaim struct {
PassHash string `json:"passKey"`
ExpiryTime int64 `json:"expiryTime"`
}
func (c PublicAlbumPasswordClaim) Valid() error {
func (c LinkPasswordClaim) Valid() error {
if c.ExpiryTime < time.Microseconds() {
return errors.New("token expired")
}

View File

@@ -3,7 +3,7 @@ package ente
import (
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
)
@@ -32,6 +32,33 @@ type UpdatePublicAccessTokenRequest struct {
EnableJoin *bool `json:"enableJoin"`
}
func (ut *UpdatePublicAccessTokenRequest) Validate() error {
if ut.DeviceLimit == nil && ut.ValidTill == nil && ut.DisablePassword == nil &&
ut.Nonce == nil && ut.PassHash == nil && ut.EnableDownload == nil && ut.EnableCollect == nil {
return NewBadRequestWithMessage("all parameters are missing")
}
if ut.DeviceLimit != nil && (*ut.DeviceLimit < 0 || *ut.DeviceLimit > 50) {
return NewBadRequestWithMessage(fmt.Sprintf("device limit: %d out of range [0-50]", *ut.DeviceLimit))
}
if ut.ValidTill != nil && *ut.ValidTill != 0 && *ut.ValidTill < time.Microseconds() {
return NewBadRequestWithMessage("valid till should be greater than current timestamp")
}
var allPassParamsMissing = ut.Nonce == nil && ut.PassHash == nil && ut.MemLimit == nil && ut.OpsLimit == nil
var allPassParamsPresent = ut.Nonce != nil && ut.PassHash != nil && ut.MemLimit != nil && ut.OpsLimit != nil
if !(allPassParamsMissing || allPassParamsPresent) {
return NewBadRequestWithMessage("all password params should be either present or missing")
}
if allPassParamsPresent && ut.DisablePassword != nil && *ut.DisablePassword {
return NewBadRequestWithMessage("can not set and disable password in same request")
}
return nil
}
type VerifyPasswordRequest struct {
PassHash string `json:"passHash" binding:"required"`
}
@@ -40,8 +67,8 @@ type VerifyPasswordResponse struct {
JWTToken string `json:"jwtToken"`
}
// PublicCollectionToken represents row entity for public_collection_token table
type PublicCollectionToken struct {
// CollectionLinkRow represents row entity for public_collection_token table
type CollectionLinkRow struct {
ID int64
CollectionID int64
Token string
@@ -57,7 +84,7 @@ type PublicCollectionToken struct {
EnableJoin bool
}
func (p PublicCollectionToken) CanJoin() error {
func (p CollectionLinkRow) CanJoin() error {
if p.IsDisabled {
return NewBadRequestWithMessage("link disabled")
}

View File

@@ -0,0 +1,3 @@
DROP TABLE IF EXISTS public_file_tokens_access_history;
DROP TABLE IF EXISTS public_file_tokens;

View File

@@ -0,0 +1,46 @@
CREATE TABLE IF NOT EXISTS public_file_tokens
(
id text primary key,
file_id bigint NOT NULL,
owner_id bigint NOT NULL,
app text NOT NULL,
access_token text not null,
valid_till bigint not null DEFAULT 0,
device_limit int not null DEFAULT 0,
is_disabled bool not null DEFAULT FALSE,
enable_download bool not null DEFAULT TRUE,
pw_hash TEXT,
pw_nonce TEXT,
mem_limit BIGINT,
ops_limit BIGINT,
created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(),
updated_at bigint NOT NULL DEFAULT now_utc_micro_seconds()
);
CREATE OR REPLACE TRIGGER update_public_file_tokens_updated_at
BEFORE UPDATE
ON public_file_tokens
FOR EACH ROW
EXECUTE PROCEDURE
trigger_updated_at_microseconds_column();
CREATE TABLE IF NOT EXISTS public_file_tokens_access_history
(
id text NOT NULL,
ip text not null,
user_agent text not null,
created_at bigint NOT NULL DEFAULT now_utc_micro_seconds(),
CONSTRAINT unique_access_id_ip_ua UNIQUE (id, ip, user_agent),
CONSTRAINT fk_public_file_history_token_id
FOREIGN KEY (id)
REFERENCES public_file_tokens (id)
ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS public_file_token_unique_idx ON public_file_tokens (access_token) WHERE is_disabled = FALSE;
CREATE INDEX IF NOT EXISTS public_file_tokens_owner_id_updated_at_idx ON public_file_tokens (owner_id, updated_at);
CREATE UNIQUE INDEX IF NOT EXISTS public_active_file_link_unique_idx ON public_file_tokens (file_id, is_disabled) WHERE is_disabled = FALSE;

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/museum/pkg/utils/handler"
"github.com/ente-io/museum/pkg/utils/time"
@@ -172,35 +171,6 @@ func (h *CollectionHandler) UpdateShareURL(c *gin.Context) {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
if req.DeviceLimit == nil && req.ValidTill == nil && req.DisablePassword == nil &&
req.Nonce == nil && req.PassHash == nil && req.EnableDownload == nil && req.EnableCollect == nil {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all parameters are missing"))
return
}
if req.DeviceLimit != nil && (*req.DeviceLimit < 0 || *req.DeviceLimit > controller.DeviceLimitThreshold) {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("device limit: %d out of range", *req.DeviceLimit)))
return
}
if req.ValidTill != nil && *req.ValidTill != 0 && *req.ValidTill < time.Microseconds() {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "valid till should be greater than current timestamp"))
return
}
var allPassParamsMissing = req.Nonce == nil && req.PassHash == nil && req.MemLimit == nil && req.OpsLimit == nil
var allPassParamsPresent = req.Nonce != nil && req.PassHash != nil && req.MemLimit != nil && req.OpsLimit != nil
if !(allPassParamsMissing || allPassParamsPresent) {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "all password params should be either present or missing"))
return
}
if allPassParamsPresent && req.DisablePassword != nil && *req.DisablePassword {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "can not set and disable password in same request"))
return
}
response, err := h.Controller.UpdateShareURL(c, auth.GetUserID(c.Request.Header), req)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/ente-io/museum/pkg/controller/file_copy"
"github.com/ente-io/museum/pkg/controller/filedata"
"github.com/ente-io/museum/pkg/controller/public"
"net/http"
"os"
"strconv"
@@ -24,6 +25,7 @@ import (
// FileHandler exposes request handlers for all encrypted file related requests
type FileHandler struct {
Controller *controller.FileController
FileUrlCtrl *public.FileLinkController
FileCopyCtrl *file_copy.FileCopyController
FileDataCtrl *filedata.Controller
}

141
server/pkg/api/file_link.go Normal file
View File

@@ -0,0 +1,141 @@
package api
import (
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/museum/pkg/utils/handler"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"net/http"
"strconv"
)
// ShareUrl a sharable url for the file
func (h *FileHandler) ShareUrl(c *gin.Context) {
var file ente.CreateFileUrl
if err := c.ShouldBindJSON(&file); err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
response, err := h.FileUrlCtrl.CreateLink(c, file)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, response)
}
func (h *FileHandler) LinkInfo(c *gin.Context) {
resp, err := h.FileUrlCtrl.Info(c)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{
"file": resp,
})
}
func (h *FileHandler) PasswordInfo(c *gin.Context) {
resp, err := h.FileUrlCtrl.PassInfo(c)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{
"nonce": resp.Nonce,
"opsLimit": resp.OpsLimit,
"memLimit": resp.MemLimit,
})
}
func (h *FileHandler) LinkThumbnail(c *gin.Context) {
linkCtx := auth.MustGetFileLinkAccessContext(c)
url, err := h.Controller.GetThumbnailURL(c, linkCtx.OwnerID, linkCtx.FileID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.Redirect(http.StatusTemporaryRedirect, url)
}
func (h *FileHandler) LinkFile(c *gin.Context) {
linkCtx := auth.MustGetFileLinkAccessContext(c)
url, err := h.Controller.GetFileURL(c, linkCtx.OwnerID, linkCtx.FileID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.Redirect(http.StatusTemporaryRedirect, url)
}
func (h *FileHandler) DisableUrl(c *gin.Context) {
cID, err := strconv.ParseInt(c.Param("fileID"), 10, 64)
if err != nil {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, ""))
return
}
err = h.FileUrlCtrl.Disable(c, cID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{})
}
func (h *FileHandler) GetUrls(c *gin.Context) {
sinceTime, err := strconv.ParseInt(c.Query("sinceTime"), 10, 64)
if err != nil {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "sinceTime parsing failed"))
return
}
limit := 500
if c.Query("limit") != "" {
limit, err = strconv.Atoi(c.Query("limit"))
if err != nil || limit < 1 {
handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, ""))
return
}
}
response, err := h.FileUrlCtrl.GetUrls(c, sinceTime, int64(limit))
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{
"diff": response,
})
}
// VerifyPassword verifies the password for given link access token and return signed jwt token if it's valid
func (h *FileHandler) VerifyPassword(c *gin.Context) {
var req ente.VerifyPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
resp, err := h.FileUrlCtrl.VerifyPassword(c, req)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, resp)
}
// UpdateFileURL updates the share URL for a file
func (h *FileHandler) UpdateFileURL(c *gin.Context) {
var req ente.UpdateFileUrl
if err := c.ShouldBindJSON(&req); err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
response, err := h.FileUrlCtrl.UpdateSharedUrl(c, req)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{
"result": response,
})
}

View File

@@ -5,6 +5,7 @@ import (
fileData "github.com/ente-io/museum/ente/filedata"
"github.com/ente-io/museum/pkg/controller/collections"
"github.com/ente-io/museum/pkg/controller/filedata"
"github.com/ente-io/museum/pkg/controller/public"
"net/http"
"strconv"
@@ -20,7 +21,7 @@ import (
// PublicCollectionHandler exposes request handlers for publicly accessible collections
type PublicCollectionHandler struct {
Controller *controller.PublicCollectionController
Controller *public.CollectionLinkController
FileCtrl *controller.FileController
CollectionCtrl *collections.CollectionController
FileDataCtrl *filedata.Controller

View File

@@ -6,6 +6,7 @@ import (
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/controller/access"
"github.com/ente-io/museum/pkg/controller/email"
"github.com/ente-io/museum/pkg/controller/public"
"github.com/ente-io/museum/pkg/repo/cast"
"github.com/ente-io/museum/pkg/utils/array"
"github.com/ente-io/museum/pkg/utils/auth"
@@ -24,16 +25,16 @@ const (
// CollectionController encapsulates logic that deals with collections
type CollectionController struct {
PublicCollectionCtrl *controller.PublicCollectionController
EmailCtrl *email.EmailNotificationController
AccessCtrl access.Controller
BillingCtrl *controller.BillingController
CollectionRepo *repo.CollectionRepository
UserRepo *repo.UserRepository
FileRepo *repo.FileRepository
QueueRepo *repo.QueueRepository
CastRepo *cast.Repository
TaskRepo *repo.TaskLockRepository
CollectionLinkCtrl *public.CollectionLinkController
EmailCtrl *email.EmailNotificationController
AccessCtrl access.Controller
BillingCtrl *controller.BillingController
CollectionRepo *repo.CollectionRepository
UserRepo *repo.UserRepository
FileRepo *repo.FileRepository
QueueRepo *repo.QueueRepository
CastRepo *cast.Repository
TaskRepo *repo.TaskLockRepository
}
// Create creates a collection
@@ -148,7 +149,7 @@ func (c *CollectionController) TrashV3(ctx *gin.Context, req ente.TrashCollectio
}
}
err = c.PublicCollectionCtrl.Disable(ctx, cID)
err = c.CollectionLinkCtrl.Disable(ctx, cID)
if err != nil {
return stacktrace.Propagate(err, "failed to disabled public share url")
}
@@ -209,7 +210,7 @@ func (c *CollectionController) HandleAccountDeletion(ctx context.Context, userID
if err != nil {
return stacktrace.Propagate(err, "failed to revoke cast token for user")
}
err = c.PublicCollectionCtrl.HandleAccountDeletion(ctx, userID, logger)
err = c.CollectionLinkCtrl.HandleAccountDeletion(ctx, userID, logger)
return stacktrace.Propagate(err, "")
}

View File

@@ -70,21 +70,21 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec
if !collection.AllowSharing() {
return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("joining %s is not allowed", collection.Type))
}
publicCollectionToken, err := c.PublicCollectionCtrl.GetActivePublicCollectionToken(ctx, req.CollectionID)
collectionLinkToken, err := c.CollectionLinkCtrl.GetActiveCollectionLinkToken(ctx, req.CollectionID)
if err != nil {
return stacktrace.Propagate(err, "")
}
if canJoin := publicCollectionToken.CanJoin(); canJoin != nil {
if canJoin := collectionLinkToken.CanJoin(); canJoin != nil {
return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("can not join collection: %s", canJoin.Error()))
}
accessToken := auth.GetAccessToken(ctx)
if publicCollectionToken.Token != accessToken {
if collectionLinkToken.Token != accessToken {
return stacktrace.Propagate(ente.ErrPermissionDenied, "token doesn't match collection")
}
if publicCollectionToken.PassHash != nil && *publicCollectionToken.PassHash != "" {
if collectionLinkToken.PassHash != nil && *collectionLinkToken.PassHash != "" {
accessTokenJWT := auth.GetAccessTokenJWT(ctx)
if passCheckErr := c.PublicCollectionCtrl.ValidateJWTToken(ctx, accessTokenJWT, *publicCollectionToken.PassHash); passCheckErr != nil {
if passCheckErr := c.CollectionLinkCtrl.ValidateJWTToken(ctx, accessTokenJWT, *collectionLinkToken.PassHash); passCheckErr != nil {
return stacktrace.Propagate(passCheckErr, "")
}
}
@@ -93,7 +93,7 @@ func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollec
return stacktrace.Propagate(err, "")
}
role := ente.VIEWER
if publicCollectionToken.EnableCollect {
if collectionLinkToken.EnableCollect {
role = ente.COLLABORATOR
}
joinErr := c.CollectionRepo.Share(req.CollectionID, collection.Owner.ID, userID, req.EncryptedKey, role, time.Microseconds())
@@ -197,7 +197,7 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e
if err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
}
response, err := c.PublicCollectionCtrl.CreateAccessToken(ctx, req)
response, err := c.CollectionLinkCtrl.CreateLink(ctx, req)
if err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
}
@@ -205,20 +205,26 @@ func (c *CollectionController) ShareURL(ctx context.Context, userID int64, req e
}
// UpdateShareURL updates the shared url configuration
func (c *CollectionController) UpdateShareURL(ctx context.Context, userID int64, req ente.UpdatePublicAccessTokenRequest) (
ente.PublicURL, error) {
func (c *CollectionController) UpdateShareURL(
ctx context.Context,
userID int64,
req ente.UpdatePublicAccessTokenRequest,
) (*ente.PublicURL, error) {
if err := req.Validate(); err != nil {
return nil, stacktrace.Propagate(err, "")
}
if err := c.verifyOwnership(req.CollectionID, userID); err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
return nil, stacktrace.Propagate(err, "")
}
err := c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
if err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
return nil, stacktrace.Propagate(err, "")
}
response, err := c.PublicCollectionCtrl.UpdateSharedUrl(ctx, req)
response, err := c.CollectionLinkCtrl.UpdateSharedUrl(ctx, req)
if err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
return nil, stacktrace.Propagate(err, "")
}
return response, nil
return &response, nil
}
// DisableSharedURL disable a public auth-token for the given collectionID
@@ -226,7 +232,7 @@ func (c *CollectionController) DisableSharedURL(ctx context.Context, userID int6
if err := c.verifyOwnership(cID, userID); err != nil {
return stacktrace.Propagate(err, "")
}
err := c.PublicCollectionCtrl.Disable(ctx, cID)
err := c.CollectionLinkCtrl.Disable(ctx, cID)
return stacktrace.Propagate(err, "")
}

View File

@@ -1,12 +1,13 @@
package controller
package public
import (
"context"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/repo/public"
"github.com/ente-io/museum/ente"
enteJWT "github.com/ente-io/museum/ente/jwt"
emailCtrl "github.com/ente-io/museum/pkg/controller/email"
"github.com/ente-io/museum/pkg/repo"
"github.com/ente-io/museum/pkg/utils/auth"
@@ -14,7 +15,6 @@ import (
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/lithammer/shortuuid/v3"
"github.com/sirupsen/logrus"
)
@@ -49,23 +49,24 @@ const (
AbuseLimitExceededTemplate = "report_limit_exceeded_alert.html"
)
// PublicCollectionController controls share collection operations
type PublicCollectionController struct {
FileController *FileController
// CollectionLinkController controls share collection operations
type CollectionLinkController struct {
FileController *controller.FileController
EmailNotificationCtrl *emailCtrl.EmailNotificationController
PublicCollectionRepo *repo.PublicCollectionRepository
CollectionLinkRepo *public.CollectionLinkRepo
FileLinkRepo *public.FileLinkRepository
CollectionRepo *repo.CollectionRepository
UserRepo *repo.UserRepository
JwtSecret []byte
}
func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) {
func (c *CollectionLinkController) CreateLink(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) {
accessToken := shortuuid.New()[0:AccessTokenLength]
err := c.PublicCollectionRepo.
err := c.CollectionLinkRepo.
Insert(ctx, req.CollectionID, accessToken, req.ValidTill, req.DeviceLimit, req.EnableCollect, req.EnableJoin)
if err != nil {
if errors.Is(err, ente.ErrActiveLinkAlreadyExists) {
collectionToPubUrlMap, err2 := c.PublicCollectionRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID})
collectionToPubUrlMap, err2 := c.CollectionLinkRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID})
if err2 != nil {
return ente.PublicURL{}, stacktrace.Propagate(err2, "")
}
@@ -81,7 +82,7 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req
}
}
response := ente.PublicURL{
URL: c.PublicCollectionRepo.GetAlbumUrl(accessToken),
URL: c.CollectionLinkRepo.GetAlbumUrl(accessToken),
ValidTill: req.ValidTill,
DeviceLimit: req.DeviceLimit,
EnableDownload: true,
@@ -91,11 +92,11 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req
return response, nil
}
func (c *PublicCollectionController) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) {
return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID)
func (c *CollectionLinkController) GetActiveCollectionLinkToken(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
return c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, collectionID)
}
func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) {
func (c *CollectionLinkController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) {
collection, err := c.GetPublicCollection(ctx, true)
if err != nil {
return ente.File{}, stacktrace.Propagate(err, "")
@@ -118,13 +119,13 @@ func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File
}
// Disable all public accessTokens generated for the given cID till date.
func (c *PublicCollectionController) Disable(ctx context.Context, cID int64) error {
err := c.PublicCollectionRepo.DisableSharing(ctx, cID)
func (c *CollectionLinkController) Disable(ctx context.Context, cID int64) error {
err := c.CollectionLinkRepo.DisableSharing(ctx, cID)
return stacktrace.Propagate(err, "")
}
func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) {
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, req.CollectionID)
func (c *CollectionLinkController) UpdateSharedUrl(ctx context.Context, req ente.UpdatePublicAccessTokenRequest) (ente.PublicURL, error) {
publicCollectionToken, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, req.CollectionID)
if err != nil {
return ente.PublicURL{}, err
}
@@ -154,12 +155,12 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en
if req.EnableJoin != nil {
publicCollectionToken.EnableJoin = *req.EnableJoin
}
err = c.PublicCollectionRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken)
err = c.CollectionLinkRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken)
if err != nil {
return ente.PublicURL{}, stacktrace.Propagate(err, "")
}
return ente.PublicURL{
URL: c.PublicCollectionRepo.GetAlbumUrl(publicCollectionToken.Token),
URL: c.CollectionLinkRepo.GetAlbumUrl(publicCollectionToken.Token),
DeviceLimit: publicCollectionToken.DeviceLimit,
ValidTill: publicCollectionToken.ValidTill,
EnableDownload: publicCollectionToken.EnableDownload,
@@ -176,58 +177,23 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en
// used by the client to pass in other requests for public collection.
// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force
// attack for guessing password.
func (c *PublicCollectionController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
func (c *CollectionLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
accessContext := auth.MustGetPublicAccessContext(ctx)
publicCollectionToken, err := c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, accessContext.CollectionID)
collectionLinkRow, err := c.CollectionLinkRepo.GetActiveCollectionLinkRow(ctx, accessContext.CollectionID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get public collection info")
}
if publicCollectionToken.PassHash == nil || *publicCollectionToken.PassHash == "" {
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
}
if req.PassHash != *publicCollectionToken.PassHash {
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.PublicAlbumPasswordClaim{
PassHash: req.PassHash,
ExpiryTime: time.NDaysFromNow(365),
})
// Sign and get the complete encoded token as a string using the secret
tokenString, err := token.SignedString(c.JwtSecret)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &ente.VerifyPasswordResponse{
JWTToken: tokenString,
}, nil
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
}
func (c *PublicCollectionController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.PublicAlbumPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
}
return c.JwtSecret, nil
})
if err != nil {
return stacktrace.Propagate(err, "JWT parsed failed")
}
claims, ok := token.Claims.(*enteJWT.PublicAlbumPasswordClaim)
if !ok {
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
}
if token.Valid && claims.PassHash == passwordHash {
return nil
}
return ente.ErrInvalidPassword
func (c *CollectionLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
}
// ReportAbuse captures abuse report for a publicly shared collection.
// It will also disable the accessToken for the collection if total abuse reports for the said collection
// reaches AutoDisableAbuseThreshold
func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error {
func (c *CollectionLinkController) ReportAbuse(ctx *gin.Context, req ente.AbuseReportRequest) error {
accessContext := auth.MustGetPublicAccessContext(ctx)
readableReason, found := AllowedReasons[req.Reason]
if !found {
@@ -235,11 +201,11 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus
}
logrus.WithField("collectionID", accessContext.CollectionID).Error("CRITICAL: received abuse report")
err := c.PublicCollectionRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details)
err := c.CollectionLinkRepo.RecordAbuseReport(ctx, accessContext, req.URL, req.Reason, req.Details)
if err != nil {
return stacktrace.Propagate(err, "")
}
count, err := c.PublicCollectionRepo.GetAbuseReportCount(ctx, accessContext)
count, err := c.CollectionLinkRepo.GetAbuseReportCount(ctx, accessContext)
if err != nil {
return stacktrace.Propagate(err, "")
}
@@ -253,7 +219,7 @@ func (c *PublicCollectionController) ReportAbuse(ctx *gin.Context, req ente.Abus
return nil
}
func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) {
func (c *CollectionLinkController) onAbuseReportReceived(collectionID int64, report ente.AbuseReportRequest, readableReason string, abuseCount int64) {
collection, err := c.CollectionRepo.Get(collectionID)
if err != nil {
logrus.Error("Could not get collection for abuse report")
@@ -292,9 +258,9 @@ func (c *PublicCollectionController) onAbuseReportReceived(collectionID int64, r
}
}
func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error {
func (c *CollectionLinkController) HandleAccountDeletion(ctx context.Context, userID int64, logger *logrus.Entry) error {
logger.Info("updating public collection on account deletion")
collectionIDs, err := c.PublicCollectionRepo.GetActivePublicTokenForUser(ctx, userID)
collectionIDs, err := c.CollectionLinkRepo.GetActivePublicTokenForUser(ctx, userID)
if err != nil {
return stacktrace.Propagate(err, "")
}
@@ -305,12 +271,12 @@ func (c *PublicCollectionController) HandleAccountDeletion(ctx context.Context,
return stacktrace.Propagate(err, "")
}
}
return nil
return c.FileLinkRepo.DisableLinksForUser(ctx, userID)
}
// GetPublicCollection will return collection info for a public url.
// is mustAllowCollect is set to true but the underlying collection doesn't allow uploading
func (c *PublicCollectionController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) {
func (c *CollectionLinkController) GetPublicCollection(ctx *gin.Context, mustAllowCollect bool) (ente.Collection, error) {
accessContext := auth.MustGetPublicAccessContext(ctx)
collection, err := c.CollectionRepo.Get(accessContext.CollectionID)
if err != nil {

View File

@@ -0,0 +1,162 @@
package public
import (
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/repo"
"github.com/ente-io/museum/pkg/repo/public"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/lithammer/shortuuid/v3"
)
// FileLinkController controls share collection operations
type FileLinkController struct {
FileController *controller.FileController
FileLinkRepo *public.FileLinkRepository
FileRepo *repo.FileRepository
JwtSecret []byte
}
func (c *FileLinkController) CreateLink(ctx *gin.Context, req ente.CreateFileUrl) (*ente.FileUrl, error) {
actorUserID := auth.GetUserID(ctx.Request.Header)
app := auth.GetApp(ctx)
if req.App != app {
return nil, stacktrace.Propagate(ente.NewBadRequestWithMessage("app mismatch"), "app mismatch")
}
file, err := c.FileRepo.GetFileAttributes(req.FileID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get file attributes")
}
if actorUserID != file.OwnerID {
return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
}
accessToken := shortuuid.New()[0:AccessTokenLength]
_, err = c.FileLinkRepo.Insert(ctx, req.FileID, actorUserID, accessToken, app)
if err == nil || err == ente.ErrActiveLinkAlreadyExists {
row, rowErr := c.FileLinkRepo.GetFileUrlRowByFileID(ctx, req.FileID)
if rowErr != nil {
return nil, stacktrace.Propagate(rowErr, "failed to get active file url token")
}
return c.mapRowToFileUrl(ctx, row), nil
}
return nil, stacktrace.Propagate(err, "failed to create public file link")
}
// Disable all public accessTokens generated for the given fileID till date.
func (c *FileLinkController) Disable(ctx *gin.Context, fileID int64) error {
userID := auth.GetUserID(ctx.Request.Header)
file, err := c.FileRepo.GetFileAttributes(fileID)
if err != nil {
return stacktrace.Propagate(err, "failed to get file attributes")
}
if userID != file.OwnerID {
return stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
}
return c.FileLinkRepo.DisableLinkForFiles(ctx, []int64{fileID})
}
func (c *FileLinkController) GetUrls(ctx *gin.Context, sinceTime int64, limit int64) ([]*ente.FileUrl, error) {
userID := auth.GetUserID(ctx.Request.Header)
app := auth.GetApp(ctx)
fileLinks, err := c.FileLinkRepo.GetFileUrls(ctx, userID, sinceTime, limit, app)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get file urls")
}
var fileUrls []*ente.FileUrl
for _, row := range fileLinks {
fileUrls = append(fileUrls, c.mapRowToFileUrl(ctx, row))
}
return fileUrls, nil
}
func (c *FileLinkController) UpdateSharedUrl(ctx *gin.Context, req ente.UpdateFileUrl) (*ente.FileUrl, error) {
if err := req.Validate(); err != nil {
return nil, stacktrace.Propagate(err, "invalid request")
}
fileLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, req.FileID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get file link info")
}
if fileLinkRow.OwnerID != auth.GetUserID(ctx.Request.Header) {
return nil, stacktrace.Propagate(ente.NewPermissionDeniedError("not file owner"), "")
}
if req.ValidTill != nil {
fileLinkRow.ValidTill = *req.ValidTill
}
if req.DeviceLimit != nil {
fileLinkRow.DeviceLimit = *req.DeviceLimit
}
if req.PassHash != nil && req.Nonce != nil && req.OpsLimit != nil && req.MemLimit != nil {
fileLinkRow.PassHash = req.PassHash
fileLinkRow.Nonce = req.Nonce
fileLinkRow.OpsLimit = req.OpsLimit
fileLinkRow.MemLimit = req.MemLimit
} else if req.DisablePassword != nil && *req.DisablePassword {
fileLinkRow.PassHash = nil
fileLinkRow.Nonce = nil
fileLinkRow.OpsLimit = nil
fileLinkRow.MemLimit = nil
}
if req.EnableDownload != nil {
fileLinkRow.EnableDownload = *req.EnableDownload
}
err = c.FileLinkRepo.UpdateLink(ctx, *fileLinkRow)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return c.mapRowToFileUrl(ctx, fileLinkRow), nil
}
func (c *FileLinkController) Info(ctx *gin.Context) (*ente.File, error) {
accessContext := auth.MustGetFileLinkAccessContext(ctx)
return c.FileRepo.GetFileAttributes(accessContext.FileID)
}
func (c *FileLinkController) PassInfo(ctx *gin.Context) (*ente.FileLinkRow, error) {
accessContext := auth.MustGetFileLinkAccessContext(ctx)
return c.FileLinkRepo.GetFileUrlRowByFileID(ctx, accessContext.FileID)
}
// VerifyPassword verifies if the user has provided correct pw hash. If yes, it returns a signed jwt token which can be
// used by the client to pass in other requests for public collection.
// Having a separate endpoint for password validation allows us to easily rate-limit the attempts for brute-force
// attack for guessing password.
func (c *FileLinkController) VerifyPassword(ctx *gin.Context, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
accessContext := auth.MustGetFileLinkAccessContext(ctx)
collectionLinkRow, err := c.FileLinkRepo.GetActiveFileUrlToken(ctx, accessContext.FileID)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get public collection info")
}
return verifyPassword(c.JwtSecret, collectionLinkRow.PassHash, req)
}
func (c *FileLinkController) ValidateJWTToken(ctx *gin.Context, jwtToken string, passwordHash string) error {
return validateJWTToken(c.JwtSecret, jwtToken, passwordHash)
}
func (c *FileLinkController) mapRowToFileUrl(ctx *gin.Context, row *ente.FileLinkRow) *ente.FileUrl {
app := auth.GetApp(ctx)
var url string
if app == ente.Locker {
url = c.FileLinkRepo.LockerFileLink(row.Token)
} else {
url = c.FileLinkRepo.PhotoLink(row.Token)
}
return &ente.FileUrl{
LinkID: row.LinkID,
FileID: row.FileID,
URL: url,
OwnerID: row.OwnerID,
ValidTill: row.ValidTill,
DeviceLimit: row.DeviceLimit,
PasswordEnabled: row.PassHash != nil,
Nonce: row.Nonce,
OpsLimit: row.OpsLimit,
MemLimit: row.MemLimit,
EnableDownload: row.EnableDownload,
CreatedAt: row.CreatedAt,
}
}

View File

@@ -0,0 +1,54 @@
package public
import (
"errors"
"fmt"
"github.com/ente-io/museum/ente"
enteJWT "github.com/ente-io/museum/ente/jwt"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/golang-jwt/jwt"
)
func validateJWTToken(secret []byte, jwtToken string, passwordHash string) error {
token, err := jwt.ParseWithClaims(jwtToken, &enteJWT.LinkPasswordClaim{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return stacktrace.Propagate(fmt.Errorf("unexpected signing method: %v", token.Header["alg"]), ""), nil
}
return secret, nil
})
if err != nil {
return stacktrace.Propagate(err, "JWT parsed failed")
}
claims, ok := token.Claims.(*enteJWT.LinkPasswordClaim)
if !ok {
return stacktrace.Propagate(errors.New("no claim in jwt token"), "")
}
if token.Valid && claims.PassHash == passwordHash {
return nil
}
return ente.ErrInvalidPassword
}
func verifyPassword(secret []byte, expectedPassHash *string, req ente.VerifyPasswordRequest) (*ente.VerifyPasswordResponse, error) {
if expectedPassHash == nil || *expectedPassHash == "" {
return nil, stacktrace.Propagate(ente.ErrBadRequest, "password is not configured for the link")
}
if req.PassHash != *expectedPassHash {
return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "incorrect password for link")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &enteJWT.LinkPasswordClaim{
PassHash: req.PassHash,
ExpiryTime: time.NDaysFromNow(365),
})
// Sign and get the complete encoded token as a string using the secret
tokenString, err := token.SignedString(secret)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &ente.VerifyPasswordResponse{
JWTToken: tokenString,
}, nil
}

View File

@@ -5,6 +5,8 @@ import (
"context"
"crypto/sha256"
"fmt"
public2 "github.com/ente-io/museum/pkg/controller/public"
"github.com/ente-io/museum/pkg/repo/public"
"net/http"
"github.com/ente-io/museum/ente"
@@ -24,20 +26,20 @@ import (
var passwordWhiteListedURLs = []string{"/public-collection/info", "/public-collection/report-abuse", "/public-collection/verify-password"}
var whitelistedCollectionShareIDs = []int64{111}
// AccessTokenMiddleware intercepts and authenticates incoming requests
type AccessTokenMiddleware struct {
PublicCollectionRepo *repo.PublicCollectionRepository
PublicCollectionCtrl *controller.PublicCollectionController
// CollectionLinkMiddleware intercepts and authenticates incoming requests
type CollectionLinkMiddleware struct {
CollectionLinkRepo *public.CollectionLinkRepo
PublicCollectionCtrl *public2.CollectionLinkController
CollectionRepo *repo.CollectionRepository
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
}
// AccessTokenAuthMiddleware returns a middle ware that extracts the `X-Auth-Access-Token`
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
// within the header of a request and uses it to validate the access token and set the
// ente.PublicAccessContext with auth.PublicAccessKey as key
func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
accessToken := auth.GetAccessToken(c)
if accessToken == "" {
@@ -52,7 +54,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
cachedValue, cacheHit := m.Cache.Get(cacheKey)
if !cacheHit {
publicCollectionSummary, err = m.PublicCollectionRepo.GetCollectionSummaryByToken(c, accessToken)
publicCollectionSummary, err = m.CollectionLinkRepo.GetCollectionSummaryByToken(c, accessToken)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
@@ -112,7 +114,7 @@ func (m *AccessTokenMiddleware) AccessTokenAuthMiddleware(urlSanitizer func(_ *g
c.Next()
}
}
func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error {
userID, err := m.CollectionRepo.GetOwnerID(cID)
if err != nil {
return stacktrace.Propagate(err, "")
@@ -120,7 +122,7 @@ func (m *AccessTokenMiddleware) validateOwnersSubscription(cID int64) error {
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
}
func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context,
collectionSummary ente.PublicCollectionSummary, ip string, ua string) (bool, error) {
// skip deviceLimit check & record keeping for requests via CF worker
if network.IsCFWorkerIP(ip) {
@@ -128,7 +130,7 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
}
sharedID := collectionSummary.ID
hasAccessedInPast, err := m.PublicCollectionRepo.AccessedInPast(ctx, sharedID, ip, ua)
hasAccessedInPast, err := m.CollectionLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
if err != nil {
return false, stacktrace.Propagate(err, "")
}
@@ -136,17 +138,17 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
if hasAccessedInPast {
return false, nil
}
count, err := m.PublicCollectionRepo.GetUniqueAccessCount(ctx, sharedID)
count, err := m.CollectionLinkRepo.GetUniqueAccessCount(ctx, sharedID)
if err != nil {
return false, stacktrace.Propagate(err, "failed to get unique access count")
}
deviceLimit := int64(collectionSummary.DeviceLimit)
if deviceLimit == controller.DeviceLimitThreshold {
deviceLimit = controller.DeviceLimitThresholdMultiplier * controller.DeviceLimitThreshold
if deviceLimit == public2.DeviceLimitThreshold {
deviceLimit = public2.DeviceLimitThresholdMultiplier * public2.DeviceLimitThreshold
}
if count >= controller.DeviceLimitWarningThreshold {
if count >= public2.DeviceLimitWarningThreshold {
if !array.Int64InList(sharedID, whitelistedCollectionShareIDs) {
m.DiscordController.NotifyPotentialAbuse(
fmt.Sprintf("Album exceeds warning threshold: {CollectionID: %d, ShareID: %d}",
@@ -157,12 +159,12 @@ func (m *AccessTokenMiddleware) isDeviceLimitReached(ctx context.Context,
if deviceLimit > 0 && count >= deviceLimit {
return true, nil
}
err = m.PublicCollectionRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
err = m.CollectionLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
return false, stacktrace.Propagate(err, "failed to record access history")
}
// validatePassword will verify if the user is provided correct password for the public album
func (m *AccessTokenMiddleware) validatePassword(c *gin.Context, reqPath string,
func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath string,
collectionSummary ente.PublicCollectionSummary) error {
if array.StringInList(reqPath, passwordWhiteListedURLs) {
return nil

View File

@@ -0,0 +1,168 @@
package middleware
import (
"context"
"fmt"
publicCtrl "github.com/ente-io/museum/pkg/controller/public"
"github.com/ente-io/museum/pkg/repo/public"
"github.com/ente-io/museum/pkg/utils/array"
"net/http"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/controller"
"github.com/ente-io/museum/pkg/controller/discord"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/museum/pkg/utils/network"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/sirupsen/logrus"
)
var filePasswordWhiteListedURLs = []string{"/file-link/pass-info", "/file-link/verify-password"}
// FileLinkMiddleware intercepts and authenticates incoming requests
type FileLinkMiddleware struct {
FileLinkRepo *public.FileLinkRepository
FileLinkCtrl *publicCtrl.FileLinkController
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
}
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
// within the header of a request and uses it to validate the access token and set the
// ente.PublicAccessContext with auth.PublicAccessKey as key
func (m *FileLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
accessToken := auth.GetAccessToken(c)
if accessToken == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing accessToken"})
return
}
clientIP := network.GetClientIP(c)
userAgent := c.GetHeader("User-Agent")
cacheKey := computeHashKeyForList([]string{accessToken, clientIP, userAgent}, ":")
cachedValue, cacheHit := m.Cache.Get(cacheKey)
var fileLinkRow *ente.FileLinkRow
var err error
if !cacheHit {
fileLinkRow, err = m.FileLinkRepo.GetFileUrlRowByToken(c, accessToken)
if err != nil {
logrus.WithError(err).Info("failed to get file link row by token")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if fileLinkRow.IsDisabled {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "disabled token"})
return
}
// validate if user still has active paid subscription
if err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(fileLinkRow.OwnerID, true); err != nil {
logrus.WithError(err).Info("failed to verify active paid subscription")
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
return
}
// validate device limit
reached, limitErr := m.isDeviceLimitReached(c, fileLinkRow, clientIP, userAgent)
if limitErr != nil {
logrus.WithError(limitErr).Error("failed to check device limit")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "something went wrong"})
return
}
if reached {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "reached device limit"})
return
}
} else {
fileLinkRow = cachedValue.(*ente.FileLinkRow)
}
if fileLinkRow.ValidTill > 0 && // expiry time is defined, 0 indicates no expiry
fileLinkRow.ValidTill < time.Microseconds() {
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "expired token"})
return
}
// checks password protected public collection
if fileLinkRow.PassHash != nil && *fileLinkRow.PassHash != "" {
reqPath := urlSanitizer(c)
if err = m.validatePassword(c, reqPath, fileLinkRow); err != nil {
logrus.WithError(err).Warn("password validation failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err})
return
}
}
if !cacheHit {
m.Cache.Set(cacheKey, fileLinkRow, cache.DefaultExpiration)
}
c.Set(auth.FileLinkAccessKey, &ente.FileLinkAccessContext{
LinkID: fileLinkRow.LinkID,
IP: clientIP,
UserAgent: userAgent,
FileID: fileLinkRow.FileID,
OwnerID: fileLinkRow.OwnerID,
})
c.Next()
}
}
func (m *FileLinkMiddleware) isDeviceLimitReached(ctx context.Context,
collectionSummary *ente.FileLinkRow, ip string, ua string) (bool, error) {
// skip deviceLimit check & record keeping for requests via CF worker
if network.IsCFWorkerIP(ip) {
return false, nil
}
sharedID := collectionSummary.LinkID
hasAccessedInPast, err := m.FileLinkRepo.AccessedInPast(ctx, sharedID, ip, ua)
if err != nil {
return false, stacktrace.Propagate(err, "")
}
// if the device has accessed the url in the past, let it access it now as well, irrespective of device limit.
if hasAccessedInPast {
return false, nil
}
count, err := m.FileLinkRepo.GetUniqueAccessCount(ctx, sharedID)
if err != nil {
return false, stacktrace.Propagate(err, "failed to get unique access count")
}
deviceLimit := int64(collectionSummary.DeviceLimit)
if deviceLimit == publicCtrl.DeviceLimitThreshold {
deviceLimit = publicCtrl.DeviceLimitThresholdMultiplier * publicCtrl.DeviceLimitThreshold
}
if count >= publicCtrl.DeviceLimitWarningThreshold {
m.DiscordController.NotifyPotentialAbuse(
fmt.Sprintf("FileLink exceeds warning threshold: {FileID: %d, ShareID: %s}",
collectionSummary.FileID, collectionSummary.LinkID))
}
if deviceLimit > 0 && count >= deviceLimit {
return true, nil
}
err = m.FileLinkRepo.RecordAccessHistory(ctx, sharedID, ip, ua)
return false, stacktrace.Propagate(err, "failed to record access history")
}
// validatePassword will verify if the user is provided correct password for the public album
func (m *FileLinkMiddleware) validatePassword(
c *gin.Context,
reqPath string,
fileLinkRow *ente.FileLinkRow,
) error {
accessTokenJWT := auth.GetAccessTokenJWT(c)
if accessTokenJWT == "" {
if array.StringInList(reqPath, filePasswordWhiteListedURLs) {
return nil
}
return &ente.ErrPassProtectedResource
}
return m.FileLinkCtrl.ValidateJWTToken(c, accessTokenJWT, *fileLinkRow.PassHash)
}

View File

@@ -140,6 +140,7 @@ func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limi
reqPath == "/users/verify-email" ||
reqPath == "/user/change-email" ||
reqPath == "/public-collection/verify-password" ||
reqPath == "/file-link/verify-password" ||
reqPath == "/family/accept-invite" ||
reqPath == "/users/srp/attributes" ||
(reqPath == "/cast/device-info" && reqMethod == "POST") ||

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/ente-io/museum/pkg/repo/public"
"strconv"
t "time"
@@ -22,13 +23,13 @@ import (
// CollectionRepository defines the methods for inserting, updating and
// retrieving collection entities from the underlying repository
type CollectionRepository struct {
DB *sql.DB
FileRepo *FileRepository
PublicCollectionRepo *PublicCollectionRepository
TrashRepo *TrashRepository
SecretEncryptionKey []byte
QueueRepo *QueueRepository
LatencyLogger *prometheus.HistogramVec
DB *sql.DB
FileRepo *FileRepository
CollectionLinkRepo *public.CollectionLinkRepo
TrashRepo *TrashRepository
SecretEncryptionKey []byte
QueueRepo *QueueRepository
LatencyLogger *prometheus.HistogramVec
}
type SharedCollection struct {
@@ -74,7 +75,7 @@ func (repo *CollectionRepository) Get(collectionID int64) (ente.Collection, erro
c.EncryptedName = encryptedName.String
c.NameDecryptionNonce = nameDecryptionNonce.String
}
urlMap, err := repo.PublicCollectionRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID})
urlMap, err := repo.CollectionLinkRepo.GetCollectionToActivePublicURLMap(context.Background(), []int64{collectionID})
if err != nil {
return ente.Collection{}, stacktrace.Propagate(err, "failed to get publicURL info")
}
@@ -174,7 +175,7 @@ pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_
if _, ok := addPublicUrlMap[pctToken.String]; !ok {
addPublicUrlMap[pctToken.String] = true
url := ente.PublicURL{
URL: repo.PublicCollectionRepo.GetAlbumUrl(pctToken.String),
URL: repo.CollectionLinkRepo.GetAlbumUrl(pctToken.String),
DeviceLimit: int(pctDeviceLimit.Int32),
ValidTill: pctValidTill.Int64,
EnableDownload: pctEnableDownload.Bool,

View File

@@ -638,6 +638,16 @@ func (repo *FileRepository) GetFileAttributesForCopy(fileIDs []int64) ([]ente.Fi
return result, nil
}
func (repo *FileRepository) GetFileAttributes(fileID int64) (*ente.File, error) {
rows := repo.DB.QueryRow(`SELECT file_id, owner_id, file_decryption_header, thumbnail_decryption_header, metadata_decryption_header, encrypted_metadata, pub_magic_metadata FROM files WHERE file_id = $1`, fileID)
var file ente.File
err := rows.Scan(&file.ID, &file.OwnerID, &file.File.DecryptionHeader, &file.Thumbnail.DecryptionHeader, &file.Metadata.DecryptionHeader, &file.Metadata.EncryptedData, &file.PubicMagicMetadata)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &file, nil
}
// GetUsage gets the Storage usage of a user
// Deprecated: GetUsage is deprecated, use UsageRepository.GetUsage
func (repo *FileRepository) GetUsage(userID int64) (int64, error) {

View File

@@ -1,4 +1,4 @@
package repo
package public
import (
"context"
@@ -13,29 +13,29 @@ import (
const BaseShareURL = "https://albums.ente.io/?t=%s"
// PublicCollectionRepository defines the methods for inserting, updating and
// CollectionLinkRepo defines the methods for inserting, updating and
// retrieving entities related to public collections
type PublicCollectionRepository struct {
type CollectionLinkRepo struct {
DB *sql.DB
albumHost string
}
// NewPublicCollectionRepository ..
func NewPublicCollectionRepository(db *sql.DB, albumHost string) *PublicCollectionRepository {
// NewCollectionLinkRepository ..
func NewCollectionLinkRepository(db *sql.DB, albumHost string) *CollectionLinkRepo {
if albumHost == "" {
albumHost = "https://albums.ente.io"
}
return &PublicCollectionRepository{
return &CollectionLinkRepo{
DB: db,
albumHost: albumHost,
}
}
func (pcr *PublicCollectionRepository) GetAlbumUrl(token string) string {
func (pcr *CollectionLinkRepo) GetAlbumUrl(token string) string {
return fmt.Sprintf("%s/?t=%s", pcr.albumHost, token)
}
func (pcr *PublicCollectionRepository) Insert(ctx context.Context,
func (pcr *CollectionLinkRepo) Insert(ctx context.Context,
cID int64, token string, validTill int64, deviceLimit int, enableCollect bool, enableJoin *bool) error {
// default value for enableJoin is true
join := true
@@ -51,7 +51,7 @@ func (pcr *PublicCollectionRepository) Insert(ctx context.Context,
return stacktrace.Propagate(err, "failed to insert")
}
func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID int64) error {
func (pcr *CollectionLinkRepo) DisableSharing(ctx context.Context, cID int64) error {
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET is_disabled = true where
collection_id = $1 and is_disabled = false`, cID)
return stacktrace.Propagate(err, "failed to disable sharing")
@@ -59,7 +59,7 @@ func (pcr *PublicCollectionRepository) DisableSharing(ctx context.Context, cID i
// GetCollectionToActivePublicURLMap will return map of collectionID to PublicURLs which are not disabled yet.
// Note: The url could be expired or deviceLimit is already reached
func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) {
func (pcr *CollectionLinkRepo) GetCollectionToActivePublicURLMap(ctx context.Context, collectionIDs []int64) (map[int64][]ente.PublicURL, error) {
rows, err := pcr.DB.QueryContext(ctx, `SELECT collection_id, access_token, valid_till, device_limit, enable_download, enable_collect, enable_join, pw_nonce, mem_limit, ops_limit FROM
public_collection_tokens WHERE collection_id = ANY($1) and is_disabled = FALSE`,
pq.Array(collectionIDs))
@@ -92,26 +92,26 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con
return result, nil
}
// GetActivePublicCollectionToken will return ente.PublicCollectionToken for given collection ID
// GetActiveCollectionLinkRow will return ente.CollectionLinkRow for given collection ID
// Note: The token could be expired or deviceLimit is already reached
func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) {
func (pcr *CollectionLinkRepo) GetActiveCollectionLinkRow(ctx context.Context, collectionID int64) (ente.CollectionLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit,
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM
public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`,
collectionID)
//defer rows.Close()
ret := ente.PublicCollectionToken{}
ret := ente.CollectionLinkRow{}
err := row.Scan(&ret.ID, &ret.CollectionID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload, &ret.EnableCollect, &ret.EnableJoin)
if err != nil {
return ente.PublicCollectionToken{}, stacktrace.Propagate(err, "")
return ente.CollectionLinkRow{}, stacktrace.Propagate(err, "")
}
return ret, nil
}
// UpdatePublicCollectionToken will update the row for corresponding public collection token
func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.Context, pct ente.PublicCollectionToken) error {
func (pcr *CollectionLinkRepo) UpdatePublicCollectionToken(ctx context.Context, pct ente.CollectionLinkRow) error {
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET valid_till = $1, device_limit = $2,
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7, enable_collect = $8, enable_join = $9
where id = $10`,
@@ -119,7 +119,7 @@ func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.C
return stacktrace.Propagate(err, "failed to update public collection token")
}
func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext,
func (pcr *CollectionLinkRepo) RecordAbuseReport(ctx context.Context, accessCtx ente.PublicAccessContext,
url string, reason string, details ente.AbuseReportDetails) error {
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_abuse_report
(share_id, ip, user_agent, url, reason, details) VALUES ($1, $2, $3, $4, $5, $6)
@@ -128,7 +128,7 @@ func (pcr *PublicCollectionRepository) RecordAbuseReport(ctx context.Context, ac
return stacktrace.Propagate(err, "failed to record abuse report")
}
func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) {
func (pcr *CollectionLinkRepo) GetAbuseReportCount(ctx context.Context, accessCtx ente.PublicAccessContext) (int64, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_abuse_report WHERE share_id = $1`, accessCtx.ID)
var count int64 = 0
err := row.Scan(&count)
@@ -138,7 +138,7 @@ func (pcr *PublicCollectionRepository) GetAbuseReportCount(ctx context.Context,
return count, nil
}
func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) {
func (pcr *CollectionLinkRepo) GetUniqueAccessCount(ctx context.Context, shareId int64) (int64, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_collection_access_history WHERE share_id = $1`, shareId)
var count int64 = 0
err := row.Scan(&count)
@@ -148,7 +148,7 @@ func (pcr *PublicCollectionRepository) GetUniqueAccessCount(ctx context.Context,
return count, nil
}
func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
func (pcr *CollectionLinkRepo) RecordAccessHistory(ctx context.Context, shareID int64, ip string, ua string) error {
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_collection_access_history
(share_id, ip, user_agent) VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT unique_access_sid_ip_ua DO NOTHING;`,
@@ -157,7 +157,7 @@ func (pcr *PublicCollectionRepository) RecordAccessHistory(ctx context.Context,
}
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
func (pcr *CollectionLinkRepo) AccessedInPast(ctx context.Context, shareID int64, ip string, ua string) (bool, error) {
row := pcr.DB.QueryRowContext(ctx, `select share_id from public_collection_access_history where share_id =$1 and ip = $2 and user_agent = $3`,
shareID, ip, ua)
var tempID int64
@@ -168,7 +168,7 @@ func (pcr *PublicCollectionRepository) AccessedInPast(ctx context.Context, share
return true, stacktrace.Propagate(err, "failed to record access history")
}
func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) {
func (pcr *CollectionLinkRepo) GetCollectionSummaryByToken(ctx context.Context, accessToken string) (ente.PublicCollectionSummary, error) {
row := pcr.DB.QueryRowContext(ctx,
`SELECT sct.id, sct.collection_id, sct.is_disabled, sct.valid_till, sct.device_limit, sct.pw_hash,
sct.created_at, sct.updated_at, count(ah.share_id)
@@ -185,7 +185,7 @@ func (pcr *PublicCollectionRepository) GetCollectionSummaryByToken(ctx context.C
return result, nil
}
func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) {
func (pcr *CollectionLinkRepo) GetActivePublicTokenForUser(ctx context.Context, userID int64) ([]int64, error) {
rows, err := pcr.DB.QueryContext(ctx, `select pt.collection_id from public_collection_tokens pt left join collections c on pt.collection_id = c.collection_id where pt.is_disabled = FALSE and c.owner_id= $1;`, userID)
if err != nil {
return nil, stacktrace.Propagate(err, "")
@@ -204,7 +204,7 @@ func (pcr *PublicCollectionRepository) GetActivePublicTokenForUser(ctx context.C
}
// CleanupAccessHistory public_collection_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days
func (pcr *PublicCollectionRepository) CleanupAccessHistory(ctx context.Context) error {
func (pcr *CollectionLinkRepo) CleanupAccessHistory(ctx context.Context) error {
_, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_collection_access_history WHERE share_id IN (SELECT id FROM public_collection_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`)
if err != nil {
return stacktrace.Propagate(err, "failed to clean up public collection access history")

View File

@@ -0,0 +1,218 @@
package public
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ente-io/museum/ente/base"
"github.com/lib/pq"
"github.com/spf13/viper"
"github.com/ente-io/museum/ente"
"github.com/ente-io/stacktrace"
)
// FileLinkRepository defines the methods for inserting, updating and
// retrieving entities related to public file
type FileLinkRepository struct {
DB *sql.DB
photoHost string
lockerHost string
}
// NewFileLinkRepo ..
func NewFileLinkRepo(db *sql.DB) *FileLinkRepository {
albumHost := viper.GetString("apps.public-albums")
if albumHost == "" {
albumHost = "https://albums.ente.io"
}
lockerHost := viper.GetString("apps.public-locker")
if lockerHost == "" {
lockerHost = "https://locker.ente.io"
}
return &FileLinkRepository{
DB: db,
photoHost: albumHost,
lockerHost: lockerHost,
}
}
func (pcr *FileLinkRepository) PhotoLink(token string) string {
return fmt.Sprintf("%s/?t=%s", pcr.photoHost, token)
}
func (pcr *FileLinkRepository) LockerFileLink(token string) string {
return fmt.Sprintf("%s/?t=%s", pcr.lockerHost, token)
}
func (pcr *FileLinkRepository) Insert(
ctx context.Context,
fileID int64,
ownerID int64,
token string,
app ente.App,
) (*string, error) {
id, err := base.NewID("pft")
if err != nil {
return nil, stacktrace.Propagate(err, "failed to generate new ID for public file token")
}
_, err = pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens
(id, file_id, owner_id, access_token, app) VALUES ($1, $2, $3, $4, $5)`,
id, fileID, ownerID, token, string(app))
if err != nil {
if err.Error() == "pq: duplicate key value violates unique constraint \"public_active_file_link_unique_idx\"" {
return nil, ente.ErrActiveLinkAlreadyExists
}
return nil, stacktrace.Propagate(err, "failed to insert")
}
return id, nil
}
// GetActiveFileUrlToken will return ente.CollectionLinkRow for given collection ID
// Note: The token could be expired or deviceLimit is already reached
func (pcr *FileLinkRepository) GetActiveFileUrlToken(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT id, file_id, owner_id, access_token, valid_till, device_limit,
is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download FROM
public_file_tokens WHERE file_id = $1 and is_disabled = FALSE`,
fileID)
ret := ente.FileLinkRow{}
err := row.Scan(&ret.LinkID, &ret.FileID, ret.OwnerID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit,
&ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return &ret, nil
}
func (pcr *FileLinkRepository) GetFileUrls(ctx context.Context, userID int64, sinceTime int64, limit int64, app ente.App) ([]*ente.FileLinkRow, error) {
if limit <= 0 {
limit = 500
}
query := `SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit,
created_at, updated_at FROM public_file_tokens
WHERE owner_id = $1 AND created_at > $2 AND app = $3 ORDER BY updated_at DESC LIMIT $4`
rows, err := pcr.DB.QueryContext(ctx, query, userID, sinceTime, string(app), limit)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to get public file urls")
}
defer rows.Close()
var result []*ente.FileLinkRow
for rows.Next() {
var row ente.FileLinkRow
err = rows.Scan(&row.LinkID, &row.FileID, &row.OwnerID, &row.IsDisabled,
&row.ValidTill, &row.DeviceLimit, &row.EnableDownload,
&row.PassHash, &row.Nonce, &row.MemLimit,
&row.OpsLimit, &row.CreatedAt, &row.UpdatedAt)
if err != nil {
return nil, stacktrace.Propagate(err, "failed to scan public file url row")
}
result = append(result, &row)
}
return result, nil
}
func (pcr *FileLinkRepository) DisableLinkForFiles(ctx context.Context, fileIDs []int64) error {
if len(fileIDs) == 0 {
return nil
}
query := `UPDATE public_file_tokens SET is_disabled = TRUE WHERE file_id = ANY($1)`
_, err := pcr.DB.ExecContext(ctx, query, pq.Array(fileIDs))
if err != nil {
return stacktrace.Propagate(err, "failed to disable public file links")
}
return nil
}
// DisableLinksForUser will disable all public file links for the given user
func (pcr *FileLinkRepository) DisableLinksForUser(ctx context.Context, userID int64) error {
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET is_disabled = TRUE WHERE owner_id = $1`, userID)
if err != nil {
return stacktrace.Propagate(err, "failed to disable public file link")
}
return nil
}
func (pcr *FileLinkRepository) GetFileUrlRowByToken(ctx context.Context, accessToken string) (*ente.FileLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx,
`SELECT id, file_id, owner_id, is_disabled, valid_till, device_limit, enable_download, pw_hash, pw_nonce, mem_limit, ops_limit
created_at, updated_at
from public_file_tokens
where access_token = $1
`, accessToken)
var result = ente.FileLinkRow{}
err := row.Scan(&result.LinkID, &result.FileID, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, ente.ErrNotFound
}
return nil, stacktrace.Propagate(err, "failed to get public file url summary by token")
}
return &result, nil
}
func (pcr *FileLinkRepository) GetFileUrlRowByFileID(ctx context.Context, fileID int64) (*ente.FileLinkRow, error) {
row := pcr.DB.QueryRowContext(ctx,
`SELECT id, file_id, access_token, owner_id, is_disabled, enable_download, valid_till, device_limit, pw_hash, pw_nonce, mem_limit, ops_limit,
created_at, updated_at
from public_file_tokens
where file_id = $1 and is_disabled = FALSE`, fileID)
var result = ente.FileLinkRow{}
err := row.Scan(&result.LinkID, &result.FileID, &result.Token, &result.OwnerID, &result.IsDisabled, &result.EnableDownload, &result.ValidTill, &result.DeviceLimit, &result.PassHash, &result.Nonce, &result.MemLimit, &result.OpsLimit, &result.CreatedAt, &result.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, ente.ErrNotFound
}
return nil, stacktrace.Propagate(err, "failed to get public file url summary by file ID")
}
return &result, nil
}
// UpdateLink will update the row for corresponding public file token
func (pcr *FileLinkRepository) UpdateLink(ctx context.Context, pct ente.FileLinkRow) error {
_, err := pcr.DB.ExecContext(ctx, `UPDATE public_file_tokens SET valid_till = $1, device_limit = $2,
pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7
where id = $8`,
pct.ValidTill, pct.DeviceLimit, pct.PassHash, pct.Nonce, pct.MemLimit, pct.OpsLimit, pct.EnableDownload, pct.LinkID)
return stacktrace.Propagate(err, "failed to update public file token")
}
func (pcr *FileLinkRepository) GetUniqueAccessCount(ctx context.Context, linkId string) (int64, error) {
row := pcr.DB.QueryRowContext(ctx, `SELECT count(*) FROM public_file_tokens_access_history WHERE id = $1`, linkId)
var count int64 = 0
err := row.Scan(&count)
if err != nil {
return -1, stacktrace.Propagate(err, "")
}
return count, nil
}
func (pcr *FileLinkRepository) RecordAccessHistory(ctx context.Context, shareID string, ip string, ua string) error {
_, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_file_tokens_access_history
(id, ip, user_agent) VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT unique_access_id_ip_ua DO NOTHING;`,
shareID, ip, ua)
return stacktrace.Propagate(err, "failed to record access history")
}
// AccessedInPast returns true if the given ip, ua agent combination has accessed the url in the past
func (pcr *FileLinkRepository) AccessedInPast(ctx context.Context, shareID string, ip string, ua string) (bool, error) {
row := pcr.DB.QueryRowContext(ctx, `select id from public_file_tokens_access_history where id =$1 and ip = $2 and user_agent = $3`,
shareID, ip, ua)
var tempID int64
err := row.Scan(&tempID)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return true, stacktrace.Propagate(err, "failed to record access history")
}
// CleanupAccessHistory public_file_tokens_access_history where public_collection_tokens is disabled and the last updated time is older than 30 days
func (pcr *FileLinkRepository) CleanupAccessHistory(ctx context.Context) error {
_, err := pcr.DB.ExecContext(ctx, `DELETE FROM public_file_tokens_access_history WHERE id IN (SELECT id FROM public_file_tokens WHERE is_disabled = TRUE AND updated_at < (now_utc_micro_seconds() - (24::BIGINT * 30 * 60 * 60 * 1000 * 1000)))`)
if err != nil {
return stacktrace.Propagate(err, "failed to clean up public file access history")
}
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/repo/public"
"strings"
"github.com/ente-io/museum/ente"
@@ -32,10 +33,11 @@ type FileWithUpdatedAt struct {
}
type TrashRepository struct {
DB *sql.DB
ObjectRepo *ObjectRepository
FileRepo *FileRepository
QueueRepo *QueueRepository
DB *sql.DB
ObjectRepo *ObjectRepository
FileRepo *FileRepository
QueueRepo *QueueRepository
FileLinkRepo *public.FileLinkRepository
}
func (t *TrashRepository) InsertItems(ctx context.Context, tx *sql.Tx, userID int64, items []ente.TrashItemRequest) error {
@@ -156,6 +158,13 @@ func (t *TrashRepository) TrashFiles(fileIDs []int64, userID int64, trash ente.T
return stacktrace.Propagate(err, "")
}
err = tx.Commit()
if err == nil {
removeLinkErr := t.FileLinkRepo.DisableLinkForFiles(ctx, fileIDs)
if removeLinkErr != nil {
return stacktrace.Propagate(removeLinkErr, "failed to disable file links for files being trashed")
}
}
return stacktrace.Propagate(err, "")
}

View File

@@ -17,8 +17,9 @@ import (
)
const (
PublicAccessKey = "X-Public-Access-ID"
CastContext = "X-Cast-Context"
PublicAccessKey = "X-Public-Access-ID"
FileLinkAccessKey = "X-Public-FileLink-Access-ID"
CastContext = "X-Cast-Context"
)
// GenerateRandomBytes returns securely generated random bytes.
@@ -120,6 +121,8 @@ func GetCastToken(c *gin.Context) string {
return token
}
// GetAccessTokenJWT fetches the JWT access token from the request header or query parameters.
// This token is issued by server on password verification of links that are protected by password.
func GetAccessTokenJWT(c *gin.Context) string {
token := c.GetHeader("X-Auth-Access-Token-JWT")
if token == "" {
@@ -132,6 +135,10 @@ func MustGetPublicAccessContext(c *gin.Context) ente.PublicAccessContext {
return c.MustGet(PublicAccessKey).(ente.PublicAccessContext)
}
func MustGetFileLinkAccessContext(c *gin.Context) *ente.FileLinkAccessContext {
return c.MustGet(FileLinkAccessKey).(*ente.FileLinkAccessContext)
}
func GetCastCtx(c *gin.Context) cast.AuthContext {
return c.MustGet(CastContext).(cast.AuthContext)
}