[server] Enable metadata fetch for shared files

This commit is contained in:
Neeraj Gupta
2025-01-02 12:13:44 +05:30
parent 65d144be77
commit af533ebc1d
6 changed files with 96 additions and 11 deletions

View File

@@ -10,6 +10,7 @@ import (
type Controller interface {
GetCollection(ctx *gin.Context, req *GetCollectionParams) (*GetCollectionResponse, error)
VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error
CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error
}
// controllerImpl implements Controller

View File

@@ -15,6 +15,11 @@ type VerifyFileOwnershipParams struct {
FileIDs []int64
}
type CanAccessFileParams struct {
ActorUserID int64
FileIDs []int64
}
// VerifyFileOwnership will return error if given fileIDs are not valid or don't belong to the ownerID
func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error {
if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) {
@@ -26,3 +31,36 @@ func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwn
})
return c.FileRepo.VerifyFileOwner(ctx, req.FileIDs, ownerID, logger)
}
func (c controllerImpl) CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error {
if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) {
return stacktrace.Propagate(ente.ErrBadRequest, "duplicate fileIDs")
}
ownerToFilesMap, err := c.FileRepo.GetOwnerToFileIDsMap(ctx, req.FileIDs)
if err != nil {
return stacktrace.Propagate(err, "failed to get owner to fileIDs map")
}
// iterate over the map and check if the ownerID has access to the fileIDs
for owner, fileIDs := range ownerToFilesMap {
if owner == req.ActorUserID {
continue
}
cIDs, collErr := c.CollectionRepo.GetCollectionIDsSharedWithUser(req.ActorUserID)
if collErr != nil {
return stacktrace.Propagate(collErr, "")
}
cwIDS, collErr := c.CollectionRepo.GetCollectionIDsSharedWithUser(owner)
if collErr != nil {
return stacktrace.Propagate(collErr, "")
}
cIDs = append(cIDs, cwIDS...)
accessErr := c.CollectionRepo.DoAllFilesExistInGivenCollections(fileIDs, cIDs)
if accessErr != nil {
log.WithFields(log.Fields{
"req_id": requestid.Get(ctx),
}).WithError(accessErr).Error("access check failed")
return stacktrace.Propagate(ente.ErrPermissionDenied, "access denied")
}
}
return nil
}

View File

@@ -81,7 +81,7 @@ func (c *Controller) InsertOrUpdateMetadata(ctx *gin.Context, req *fileData.PutF
return stacktrace.Propagate(err, "validation failed")
}
userID := auth.GetUserID(ctx.Request.Header)
err := c._validatePermission(ctx, req.FileID, userID)
err := c._validateWritePermission(ctx, req.FileID, userID)
if err != nil {
return stacktrace.Propagate(err, "")
}
@@ -126,7 +126,7 @@ func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*f
if err := req.Validate(); err != nil {
return nil, stacktrace.Propagate(err, "validation failed")
}
if err := c._validatePermission(ctx, req.FileID, auth.GetUserID(ctx.Request.Header)); err != nil {
if err := c._validateWritePermission(ctx, req.FileID, auth.GetUserID(ctx.Request.Header)); err != nil {
return nil, stacktrace.Propagate(err, "")
}
doRows, err := c.Repo.GetFilesData(ctx, req.Type, []int64{req.FileID})
@@ -150,7 +150,7 @@ func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*f
func (c *Controller) GetFilesData(ctx *gin.Context, req fileData.GetFilesData) (*fileData.GetFilesDataResponse, error) {
userID := auth.GetUserID(ctx.Request.Header)
if err := c._validateGetFilesData(ctx, userID, req); err != nil {
if err := c._validateReadPermission(ctx, userID, req); err != nil {
return nil, stacktrace.Propagate(err, "")
}
@@ -273,21 +273,20 @@ func (c *Controller) fetchS3FileMetadata(ctx context.Context, row fileData.Row,
return nil, stacktrace.Propagate(errors.New("failed to fetch object"), "")
}
func (c *Controller) _validateGetFilesData(ctx *gin.Context, userID int64, req fileData.GetFilesData) error {
func (c *Controller) _validateReadPermission(ctx *gin.Context, userID int64, req fileData.GetFilesData) error {
if err := req.Validate(); err != nil {
return stacktrace.Propagate(err, "validation failed")
}
if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: userID,
if err := c.AccessCtrl.CanAccessFile(ctx, &access.CanAccessFileParams{
ActorUserID: userID,
FileIDs: req.FileIDs,
}); err != nil {
return stacktrace.Propagate(err, "User does not own some file(s)")
}
return nil
}
func (c *Controller) _validatePermission(ctx *gin.Context, fileID int64, actorID int64) error {
func (c *Controller) _validateWritePermission(ctx *gin.Context, fileID int64, actorID int64) error {
err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: actorID,
FileIDs: []int64{fileID},

View File

@@ -13,7 +13,7 @@ func (c *Controller) GetPreviewUrl(ctx *gin.Context, request filedata.GetPreview
return nil, err
}
actorUser := auth.GetUserID(ctx.Request.Header)
if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil {
if err := c._validateWritePermission(ctx, request.FileID, actorUser); err != nil {
return nil, err
}
data, err := c.Repo.GetFilesData(ctx, request.Type, []int64{request.FileID})
@@ -35,7 +35,7 @@ func (c *Controller) PreviewUploadURL(ctx *gin.Context, request filedata.Preview
return nil, err
}
actorUser := auth.GetUserID(ctx.Request.Header)
if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil {
if err := c._validateWritePermission(ctx, request.FileID, actorUser); err != nil {
return nil, err
}
fileOwnerID, err := c.FileRepo.GetOwnerID(request.FileID)

View File

@@ -16,7 +16,7 @@ func (c *Controller) InsertVideoPreview(ctx *gin.Context, req *filedata.VidPrevi
return stacktrace.Propagate(err, "validation failed")
}
userID := auth.GetUserID(ctx.Request.Header)
err := c._validatePermission(ctx, req.FileID, userID)
err := c._validateWritePermission(ctx, req.FileID, userID)
if err != nil {
return stacktrace.Propagate(err, "")
}

View File

@@ -375,6 +375,53 @@ func (repo *CollectionRepository) DoesFileExistInCollections(fileID int64, cIDs
return exists, stacktrace.Propagate(err, "")
}
func (repo *CollectionRepository) DoAllFilesExistInGivenCollections(fileIDs []int64, cIDs []int64) error {
// Query to get all distinct file_ids that exist in the collections
rows, err := repo.DB.Query(`
SELECT DISTINCT file_id
FROM collection_files
WHERE file_id = ANY ($1)
AND is_deleted = false
AND collection_id = ANY ($2)`,
pq.Array(fileIDs), pq.Array(cIDs))
if err != nil {
return stacktrace.Propagate(err, "")
}
defer rows.Close()
// Create a map of input fileIDs for easy lookup
fileIDMap := make(map[int64]bool)
for _, id := range fileIDs {
fileIDMap[id] = false // false means not found yet
}
// Mark files that were found
for rows.Next() {
var fileID int64
if err := rows.Scan(&fileID); err != nil {
return stacktrace.Propagate(err, "")
}
fileIDMap[fileID] = true // mark as found
}
if err = rows.Err(); err != nil {
return stacktrace.Propagate(err, "")
}
// Collect missing files
var missingFiles []int64
for id, found := range fileIDMap {
if !found {
missingFiles = append(missingFiles, id)
}
}
if len(missingFiles) > 0 {
logrus.WithField("missingFiles", missingFiles).Info("missing files")
return stacktrace.Propagate(fmt.Errorf("missing files %v", missingFiles), "")
}
return nil
}
// VerifyAllFileIDsExistsInCollection returns error if the fileIDs don't exist in the collection
func (repo *CollectionRepository) VerifyAllFileIDsExistsInCollection(ctx context.Context, cID int64, fileIDs []int64) error {
fileIdMap := make(map[int64]bool)