Add validation
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
package ente
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GetValueRequest struct {
|
||||
Key string `form:"key" binding:"required"`
|
||||
DefaultValue *string `form:"defaultValue"`
|
||||
@@ -100,7 +107,36 @@ func (k FlagKey) IsBoolType() bool {
|
||||
switch k {
|
||||
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser:
|
||||
return true
|
||||
default:
|
||||
case CustomDomain:
|
||||
return false
|
||||
default:
|
||||
return false // Explicitly handle unexpected cases
|
||||
}
|
||||
}
|
||||
|
||||
func (k FlagKey) IsValidValue(value string) error {
|
||||
if k.IsBoolType() && value != "true" && value != "false" {
|
||||
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
|
||||
}
|
||||
if k == CustomDomain && value != "" {
|
||||
if !isValidCustomDomainURL(value) {
|
||||
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("unexcpeted %s", value)), "url with https://. Also, tt should not end with trailing dash.")
|
||||
}
|
||||
// ensure that it's valid domain that starts with https and does not end with trailing dash.
|
||||
return stacktrace.Propagate(NewBadRequestWithMessage("custom domain cannot be empty"), "custom domain cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidCustomDomainURL(input string) bool {
|
||||
if !strings.HasPrefix(input, "https://") || strings.HasSuffix(input, "/") {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(input)
|
||||
if err != nil || u.Scheme != "https" || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -26,6 +26,9 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValu
|
||||
if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if request.Value == "" && ente.FlagKey(request.Key).CanRemove() {
|
||||
return c.Repo.RemoveKey(ctx, userID, request.Key)
|
||||
}
|
||||
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value)
|
||||
}
|
||||
|
||||
@@ -79,9 +82,6 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
|
||||
}
|
||||
for key, value := range values {
|
||||
flag := ente.FlagKey(key)
|
||||
if !flag.IsBoolType() {
|
||||
continue
|
||||
}
|
||||
switch flag {
|
||||
case ente.RecoveryKeyVerified:
|
||||
response.RecoveryKeyVerified = value == "true"
|
||||
@@ -109,15 +109,16 @@ func (c *Controller) _validateRequest(userID int64, key, value string, byAdmin b
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
|
||||
}
|
||||
flag := ente.FlagKey(key)
|
||||
if err := flag.IsValidValue(value); err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
if !flag.UserEditable() && !byAdmin {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable")
|
||||
}
|
||||
if byAdmin && !flag.IsAdminEditable() {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable")
|
||||
}
|
||||
if flag.IsBoolType() && value != "true" && value != "false" {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
|
||||
}
|
||||
|
||||
if flag.NeedSubscription() {
|
||||
return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user