diff --git a/server/ente/userentity/entity.go b/server/ente/userentity/entity.go index ede9b8a0dd..904b846d1a 100644 --- a/server/ente/userentity/entity.go +++ b/server/ente/userentity/entity.go @@ -69,6 +69,26 @@ type EntityDataRequest struct { Type EntityType `json:"type" binding:"required"` EncryptedData string `json:"encryptedData" binding:"required"` Header string `json:"header" binding:"required"` + ID *string `json:"id"` // Optional ID, if not provided a new ID will be generated +} + +func (edr *EntityDataRequest) IsValid(userID int64) error { + if err := edr.Type.IsValid(); err != nil { + return err + } + switch edr.Type { + case SmartAlbum: + if edr.ID == nil { + return ente.NewBadRequestWithMessage("ID is required for SmartAlbum entity type") + } + // check if ID starts with sa_userid_ or not + if !strings.HasPrefix(*edr.ID, fmt.Sprintf("sa_%d_", userID)) { + return ente.NewBadRequestWithMessage(fmt.Sprintf("ID %s is not valid for SmartAlbum entity type", *edr.ID)) + } + return nil + default: + return nil + } } // UpdateEntityDataRequest updates the current entity diff --git a/server/pkg/api/userentity.go b/server/pkg/api/userentity.go index acda89a09e..e8aa043b5b 100644 --- a/server/pkg/api/userentity.go +++ b/server/pkg/api/userentity.go @@ -61,10 +61,6 @@ func (h *UserEntityHandler) CreateEntity(c *gin.Context) { stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) return } - if err := request.Type.IsValid(); err != nil { - handler.Error(c, stacktrace.Propagate(err, "Invalid EntityType")) - return - } resp, err := h.Controller.CreateEntity(c, request) if err != nil { handler.Error(c, stacktrace.Propagate(err, "Failed to create CreateEntityKey")) diff --git a/server/pkg/controller/userentity/controller.go b/server/pkg/controller/userentity/controller.go index f4fb1c8b9b..38b99fba22 100644 --- a/server/pkg/controller/userentity/controller.go +++ b/server/pkg/controller/userentity/controller.go @@ -32,6 +32,9 @@ func (c *Controller) GetKey(ctx *gin.Context, req model.GetEntityKeyRequest) (*m // CreateEntity stores entity data for the given type func (c *Controller) CreateEntity(ctx *gin.Context, req model.EntityDataRequest) (*model.EntityData, error) { userID := auth.GetUserID(ctx.Request.Header) + if err := req.IsValid(userID); err != nil { + return nil, stacktrace.Propagate(err, "invalid EntityDataRequest") + } id, err := c.Repo.Create(ctx, userID, req) if err != nil { return nil, stacktrace.Propagate(err, "failed to createEntity") diff --git a/server/pkg/repo/userentity/data.go b/server/pkg/repo/userentity/data.go index fa6457ebfa..45f59d3051 100644 --- a/server/pkg/repo/userentity/data.go +++ b/server/pkg/repo/userentity/data.go @@ -14,12 +14,17 @@ import ( // Create inserts a new entry func (r *Repository) Create(ctx context.Context, userID int64, entry model.EntityDataRequest) (string, error) { - idPrt, err := entry.Type.GetNewID() - if err != nil { - return "", stacktrace.Propagate(err, "failed to generate new id") + var id string + if entry.ID != nil { + id = *entry.ID + } else { + idPrt, err := entry.Type.GetNewID() + if err != nil { + return "", stacktrace.Propagate(err, "failed to generate new id") + } + id = *idPrt } - id := *idPrt - err = r.DB.QueryRow(`INSERT into entity_data( + err := r.DB.QueryRow(`INSERT into entity_data( id, user_id, type,