From af533ebc1de3c210b806360f63f80792bcc799aa Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Thu, 2 Jan 2025 12:13:44 +0530 Subject: [PATCH] [server] Enable metadata fetch for shared files --- server/pkg/controller/access/access.go | 1 + server/pkg/controller/access/file.go | 38 +++++++++++++++ server/pkg/controller/filedata/controller.go | 15 +++--- .../pkg/controller/filedata/preview_files.go | 4 +- server/pkg/controller/filedata/video.go | 2 +- server/pkg/repo/collection.go | 47 +++++++++++++++++++ 6 files changed, 96 insertions(+), 11 deletions(-) diff --git a/server/pkg/controller/access/access.go b/server/pkg/controller/access/access.go index b2acbff25b..60c77d54e0 100644 --- a/server/pkg/controller/access/access.go +++ b/server/pkg/controller/access/access.go @@ -10,6 +10,7 @@ import ( type Controller interface { GetCollection(ctx *gin.Context, req *GetCollectionParams) (*GetCollectionResponse, error) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error + CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error } // controllerImpl implements Controller diff --git a/server/pkg/controller/access/file.go b/server/pkg/controller/access/file.go index 9c6cfdc160..fb50ee83a8 100644 --- a/server/pkg/controller/access/file.go +++ b/server/pkg/controller/access/file.go @@ -15,6 +15,11 @@ type VerifyFileOwnershipParams struct { FileIDs []int64 } +type CanAccessFileParams struct { + ActorUserID int64 + FileIDs []int64 +} + // VerifyFileOwnership will return error if given fileIDs are not valid or don't belong to the ownerID func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error { if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) { @@ -26,3 +31,36 @@ func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwn }) return c.FileRepo.VerifyFileOwner(ctx, req.FileIDs, ownerID, logger) } + +func (c controllerImpl) CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error { + if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) { + return stacktrace.Propagate(ente.ErrBadRequest, "duplicate fileIDs") + } + ownerToFilesMap, err := c.FileRepo.GetOwnerToFileIDsMap(ctx, req.FileIDs) + if err != nil { + return stacktrace.Propagate(err, "failed to get owner to fileIDs map") + } + // iterate over the map and check if the ownerID has access to the fileIDs + for owner, fileIDs := range ownerToFilesMap { + if owner == req.ActorUserID { + continue + } + cIDs, collErr := c.CollectionRepo.GetCollectionIDsSharedWithUser(req.ActorUserID) + if collErr != nil { + return stacktrace.Propagate(collErr, "") + } + cwIDS, collErr := c.CollectionRepo.GetCollectionIDsSharedWithUser(owner) + if collErr != nil { + return stacktrace.Propagate(collErr, "") + } + cIDs = append(cIDs, cwIDS...) + accessErr := c.CollectionRepo.DoAllFilesExistInGivenCollections(fileIDs, cIDs) + if accessErr != nil { + log.WithFields(log.Fields{ + "req_id": requestid.Get(ctx), + }).WithError(accessErr).Error("access check failed") + return stacktrace.Propagate(ente.ErrPermissionDenied, "access denied") + } + } + return nil +} diff --git a/server/pkg/controller/filedata/controller.go b/server/pkg/controller/filedata/controller.go index deea74ef91..aa04376bb6 100644 --- a/server/pkg/controller/filedata/controller.go +++ b/server/pkg/controller/filedata/controller.go @@ -81,7 +81,7 @@ func (c *Controller) InsertOrUpdateMetadata(ctx *gin.Context, req *fileData.PutF return stacktrace.Propagate(err, "validation failed") } userID := auth.GetUserID(ctx.Request.Header) - err := c._validatePermission(ctx, req.FileID, userID) + err := c._validateWritePermission(ctx, req.FileID, userID) if err != nil { return stacktrace.Propagate(err, "") } @@ -126,7 +126,7 @@ func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*f if err := req.Validate(); err != nil { return nil, stacktrace.Propagate(err, "validation failed") } - if err := c._validatePermission(ctx, req.FileID, auth.GetUserID(ctx.Request.Header)); err != nil { + if err := c._validateWritePermission(ctx, req.FileID, auth.GetUserID(ctx.Request.Header)); err != nil { return nil, stacktrace.Propagate(err, "") } doRows, err := c.Repo.GetFilesData(ctx, req.Type, []int64{req.FileID}) @@ -150,7 +150,7 @@ func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*f func (c *Controller) GetFilesData(ctx *gin.Context, req fileData.GetFilesData) (*fileData.GetFilesDataResponse, error) { userID := auth.GetUserID(ctx.Request.Header) - if err := c._validateGetFilesData(ctx, userID, req); err != nil { + if err := c._validateReadPermission(ctx, userID, req); err != nil { return nil, stacktrace.Propagate(err, "") } @@ -273,21 +273,20 @@ func (c *Controller) fetchS3FileMetadata(ctx context.Context, row fileData.Row, return nil, stacktrace.Propagate(errors.New("failed to fetch object"), "") } -func (c *Controller) _validateGetFilesData(ctx *gin.Context, userID int64, req fileData.GetFilesData) error { +func (c *Controller) _validateReadPermission(ctx *gin.Context, userID int64, req fileData.GetFilesData) error { if err := req.Validate(); err != nil { return stacktrace.Propagate(err, "validation failed") } - if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{ - ActorUserId: userID, + if err := c.AccessCtrl.CanAccessFile(ctx, &access.CanAccessFileParams{ + ActorUserID: userID, FileIDs: req.FileIDs, }); err != nil { return stacktrace.Propagate(err, "User does not own some file(s)") } - return nil } -func (c *Controller) _validatePermission(ctx *gin.Context, fileID int64, actorID int64) error { +func (c *Controller) _validateWritePermission(ctx *gin.Context, fileID int64, actorID int64) error { err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{ ActorUserId: actorID, FileIDs: []int64{fileID}, diff --git a/server/pkg/controller/filedata/preview_files.go b/server/pkg/controller/filedata/preview_files.go index 087cb825c2..d9a7d21362 100644 --- a/server/pkg/controller/filedata/preview_files.go +++ b/server/pkg/controller/filedata/preview_files.go @@ -13,7 +13,7 @@ func (c *Controller) GetPreviewUrl(ctx *gin.Context, request filedata.GetPreview return nil, err } actorUser := auth.GetUserID(ctx.Request.Header) - if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil { + if err := c._validateWritePermission(ctx, request.FileID, actorUser); err != nil { return nil, err } data, err := c.Repo.GetFilesData(ctx, request.Type, []int64{request.FileID}) @@ -35,7 +35,7 @@ func (c *Controller) PreviewUploadURL(ctx *gin.Context, request filedata.Preview return nil, err } actorUser := auth.GetUserID(ctx.Request.Header) - if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil { + if err := c._validateWritePermission(ctx, request.FileID, actorUser); err != nil { return nil, err } fileOwnerID, err := c.FileRepo.GetOwnerID(request.FileID) diff --git a/server/pkg/controller/filedata/video.go b/server/pkg/controller/filedata/video.go index 0c88b43cba..47ded8a3b1 100644 --- a/server/pkg/controller/filedata/video.go +++ b/server/pkg/controller/filedata/video.go @@ -16,7 +16,7 @@ func (c *Controller) InsertVideoPreview(ctx *gin.Context, req *filedata.VidPrevi return stacktrace.Propagate(err, "validation failed") } userID := auth.GetUserID(ctx.Request.Header) - err := c._validatePermission(ctx, req.FileID, userID) + err := c._validateWritePermission(ctx, req.FileID, userID) if err != nil { return stacktrace.Propagate(err, "") } diff --git a/server/pkg/repo/collection.go b/server/pkg/repo/collection.go index 4a0de28175..add9f6aafb 100644 --- a/server/pkg/repo/collection.go +++ b/server/pkg/repo/collection.go @@ -375,6 +375,53 @@ func (repo *CollectionRepository) DoesFileExistInCollections(fileID int64, cIDs return exists, stacktrace.Propagate(err, "") } +func (repo *CollectionRepository) DoAllFilesExistInGivenCollections(fileIDs []int64, cIDs []int64) error { + // Query to get all distinct file_ids that exist in the collections + rows, err := repo.DB.Query(` + SELECT DISTINCT file_id + FROM collection_files + WHERE file_id = ANY ($1) + AND is_deleted = false + AND collection_id = ANY ($2)`, + pq.Array(fileIDs), pq.Array(cIDs)) + + if err != nil { + return stacktrace.Propagate(err, "") + } + defer rows.Close() + + // Create a map of input fileIDs for easy lookup + fileIDMap := make(map[int64]bool) + for _, id := range fileIDs { + fileIDMap[id] = false // false means not found yet + } + // Mark files that were found + for rows.Next() { + var fileID int64 + if err := rows.Scan(&fileID); err != nil { + return stacktrace.Propagate(err, "") + } + fileIDMap[fileID] = true // mark as found + } + + if err = rows.Err(); err != nil { + return stacktrace.Propagate(err, "") + } + + // Collect missing files + var missingFiles []int64 + for id, found := range fileIDMap { + if !found { + missingFiles = append(missingFiles, id) + } + } + if len(missingFiles) > 0 { + logrus.WithField("missingFiles", missingFiles).Info("missing files") + return stacktrace.Propagate(fmt.Errorf("missing files %v", missingFiles), "") + } + return nil +} + // VerifyAllFileIDsExistsInCollection returns error if the fileIDs don't exist in the collection func (repo *CollectionRepository) VerifyAllFileIDsExistsInCollection(ctx context.Context, cID int64, fileIDs []int64) error { fileIdMap := make(map[int64]bool)