diff --git a/server/pkg/api/file.go b/server/pkg/api/file.go index 990336e372..064bc3be08 100644 --- a/server/pkg/api/file.go +++ b/server/pkg/api/file.go @@ -139,7 +139,7 @@ func (h *FileHandler) GetMultipartUploadURLs(c *gin.Context) { // Get redirects the request to the file location func (h *FileHandler) Get(c *gin.Context) { userID, fileID := getUserAndFileIDs(c) - url, err := h.Controller.GetFileURL(userID, fileID) + url, err := h.Controller.GetFileURL(c, userID, fileID) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) return @@ -151,7 +151,7 @@ func (h *FileHandler) Get(c *gin.Context) { // GetV2 returns the URL of the file to client func (h *FileHandler) GetV2(c *gin.Context) { userID, fileID := getUserAndFileIDs(c) - url, err := h.Controller.GetFileURL(userID, fileID) + url, err := h.Controller.GetFileURL(c, userID, fileID) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) return @@ -164,7 +164,7 @@ func (h *FileHandler) GetV2(c *gin.Context) { // GetThumbnail redirects the request to the file's thumbnail location func (h *FileHandler) GetThumbnail(c *gin.Context) { userID, fileID := getUserAndFileIDs(c) - url, err := h.Controller.GetThumbnailURL(userID, fileID) + url, err := h.Controller.GetThumbnailURL(c, userID, fileID) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) return @@ -176,7 +176,7 @@ func (h *FileHandler) GetThumbnail(c *gin.Context) { // GetThumbnailV2 returns the URL of the thumbnail to the client func (h *FileHandler) GetThumbnailV2(c *gin.Context) { userID, fileID := getUserAndFileIDs(c) - url, err := h.Controller.GetThumbnailURL(userID, fileID) + url, err := h.Controller.GetThumbnailURL(c, userID, fileID) if err != nil { handler.Error(c, stacktrace.Propagate(err, "")) return diff --git a/server/pkg/controller/file.go b/server/pkg/controller/file.go index b3fec115d0..a988b8c683 100644 --- a/server/pkg/controller/file.go +++ b/server/pkg/controller/file.go @@ -285,12 +285,12 @@ func (c *FileController) GetUploadURLs(ctx context.Context, userID int64, count } // GetFileURL verifies permissions and returns a presigned url to the requested file -func (c *FileController) GetFileURL(userID int64, fileID int64) (string, error) { +func (c *FileController) GetFileURL(ctx *gin.Context, userID int64, fileID int64) (string, error) { err := c.verifyFileAccess(userID, fileID) if err != nil { return "", stacktrace.Propagate(err, "") } - url, err := c.getSignedURLForType(fileID, ente.FILE) + url, err := c.getSignedURLForType(ctx, fileID, ente.FILE) if err != nil { if errors.Is(err, sql.ErrNoRows) { go c.CleanUpStaleCollectionFiles(userID, fileID) @@ -301,12 +301,12 @@ func (c *FileController) GetFileURL(userID int64, fileID int64) (string, error) } // GetThumbnailURL verifies permissions and returns a presigned url to the requested thumbnail -func (c *FileController) GetThumbnailURL(userID int64, fileID int64) (string, error) { +func (c *FileController) GetThumbnailURL(ctx *gin.Context, userID int64, fileID int64) (string, error) { err := c.verifyFileAccess(userID, fileID) if err != nil { return "", stacktrace.Propagate(err, "") } - url, err := c.getSignedURLForType(fileID, ente.THUMBNAIL) + url, err := c.getSignedURLForType(ctx, fileID, ente.THUMBNAIL) if err != nil { if errors.Is(err, sql.ErrNoRows) { go c.CleanUpStaleCollectionFiles(userID, fileID) @@ -356,7 +356,7 @@ func (c *FileController) GetPublicFileURL(ctx *gin.Context, fileID int64, objTyp if !accessible { return "", stacktrace.Propagate(ente.ErrPermissionDenied, "") } - return c.getSignedURLForType(fileID, objType) + return c.getSignedURLForType(ctx, fileID, objType) } // GetCastFileUrl verifies permissions and returns a presigned url to the requested file @@ -369,15 +369,46 @@ func (c *FileController) GetCastFileUrl(ctx *gin.Context, fileID int64, objType if !accessible { return "", stacktrace.Propagate(ente.ErrPermissionDenied, "") } - return c.getSignedURLForType(fileID, objType) + return c.getSignedURLForType(ctx, fileID, objType) } -func (c *FileController) getSignedURLForType(fileID int64, objType ente.ObjectType) (string, error) { +func (c *FileController) getSignedURLForType(ctx *gin.Context, fileID int64, objType ente.ObjectType) (string, error) { + if isCliRequest(ctx) { + return c.getWasabiSignedUrlIfAvailable(fileID, objType) + } s3Object, err := c.ObjectRepo.GetObject(fileID, objType) if err != nil { return "", stacktrace.Propagate(err, "") } - return c.getPreSignedURL(s3Object.ObjectKey) + return c.getHotDcSignedUrl(s3Object.ObjectKey) +} + +func isCliRequest(ctx *gin.Context) bool { + // check if user-agent contains go-resty + userAgent := ctx.Request.Header.Get("User-Agent") + if strings.Contains(userAgent, "go-resty") { + return true + } + return false +} + +// getWasabiSignedUrlIfAvailable returns a signed URL for the given fileID and objectType. It prefers wasabi over b2 +// if the file is not found in wasabi, it will return signed url from B2 +func (c *FileController) getWasabiSignedUrlIfAvailable(fileID int64, objType ente.ObjectType) (string, error) { + s3Object, dcs, err := c.ObjectRepo.GetObjectWithDCs(fileID, objType) + if err != nil { + return "", stacktrace.Propagate(err, "") + } + for _, dc := range dcs { + if dc == c.S3Config.GetHotWasabiDC() { + return c.getPreSignedURLForDC(s3Object.ObjectKey, dc) + } + } + // todo: (neeraj) remove this log after some time + log.WithFields(log.Fields{ + "fileID": fileID}).Info("File not found in wasabi, returning signed url from B2") + // return signed url from default hot bucket + return c.getHotDcSignedUrl(s3Object.ObjectKey) } // Trash deletes file and move them to trash @@ -704,7 +735,7 @@ func (c *FileController) cleanupDeletedFile(qItem repo.QueueItem) { ctxLogger.Info("Successfully deleted item") } -func (c *FileController) getPreSignedURL(objectKey string) (string, error) { +func (c *FileController) getHotDcSignedUrl(objectKey string) (string, error) { s3Client := c.S3Config.GetHotS3Client() r, _ := s3Client.GetObjectRequest(&s3.GetObjectInput{ Bucket: c.S3Config.GetHotBucket(), @@ -713,6 +744,15 @@ func (c *FileController) getPreSignedURL(objectKey string) (string, error) { return r.Presign(PreSignedRequestValidityDuration) } +func (c *FileController) getPreSignedURLForDC(objectKey string, dc string) (string, error) { + s3Client := c.S3Config.GetS3Client(dc) + r, _ := s3Client.GetObjectRequest(&s3.GetObjectInput{ + Bucket: c.S3Config.GetBucket(dc), + Key: &objectKey, + }) + return r.Presign(PreSignedRequestValidityDuration) +} + func (c *FileController) sizeOf(objectKey string) (int64, error) { s3Client := c.S3Config.GetHotS3Client() head, err := s3Client.HeadObject(&s3.HeadObjectInput{ diff --git a/server/pkg/repo/object.go b/server/pkg/repo/object.go index fdbbbf52c0..052278402d 100644 --- a/server/pkg/repo/object.go +++ b/server/pkg/repo/object.go @@ -64,6 +64,16 @@ func (repo *ObjectRepository) GetObject(fileID int64, objType ente.ObjectType) ( return s3ObjectKey, stacktrace.Propagate(err, "") } +func (repo *ObjectRepository) GetObjectWithDCs(fileID int64, objType ente.ObjectType) (ente.S3ObjectKey, []string, error) { + row := repo.DB.QueryRow(`SELECT object_key, size, o_type, datacenters FROM object_keys WHERE file_id = $1 AND o_type = $2 AND is_deleted=false`, + fileID, objType) + var s3ObjectKey ente.S3ObjectKey + var datacenters []string + s3ObjectKey.FileID = fileID + err := row.Scan(&s3ObjectKey.ObjectKey, &s3ObjectKey.FileSize, &s3ObjectKey.Type, pq.Array(&datacenters)) + return s3ObjectKey, datacenters, stacktrace.Propagate(err, "") +} + func (repo *ObjectRepository) GetAllFileObjectsByObjectKey(objectKey string) ([]ente.S3ObjectKey, error) { rows, err := repo.DB.Query(`SELECT file_id, o_type, object_key, size from object_keys where file_id in (select file_id from object_keys where object_key= $1)