Add validation

This commit is contained in:
Neeraj Gupta
2025-08-08 15:31:02 +05:30
parent 3167d85f06
commit 920702c5dd
2 changed files with 44 additions and 7 deletions

View File

@@ -1,5 +1,12 @@
package ente
import (
"fmt"
"github.com/ente-io/stacktrace"
"net/url"
"strings"
)
type GetValueRequest struct {
Key string `form:"key" binding:"required"`
DefaultValue *string `form:"defaultValue"`
@@ -100,7 +107,36 @@ func (k FlagKey) IsBoolType() bool {
switch k {
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser:
return true
default:
case CustomDomain:
return false
default:
return false // Explicitly handle unexpected cases
}
}
func (k FlagKey) IsValidValue(value string) error {
if k.IsBoolType() && value != "true" && value != "false" {
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
}
if k == CustomDomain && value != "" {
if !isValidCustomDomainURL(value) {
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("unexcpeted %s", value)), "url with https://. Also, tt should not end with trailing dash.")
}
// ensure that it's valid domain that starts with https and does not end with trailing dash.
return stacktrace.Propagate(NewBadRequestWithMessage("custom domain cannot be empty"), "custom domain cannot be empty")
}
return nil
}
func isValidCustomDomainURL(input string) bool {
if !strings.HasPrefix(input, "https://") || strings.HasSuffix(input, "/") {
return false
}
u, err := url.Parse(input)
if err != nil || u.Scheme != "https" || u.Host == "" {
return false
}
return true
}

View File

@@ -26,6 +26,9 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValu
if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil {
return err
}
if request.Value == "" && ente.FlagKey(request.Key).CanRemove() {
return c.Repo.RemoveKey(ctx, userID, request.Key)
}
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value)
}
@@ -79,9 +82,6 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
}
for key, value := range values {
flag := ente.FlagKey(key)
if !flag.IsBoolType() {
continue
}
switch flag {
case ente.RecoveryKeyVerified:
response.RecoveryKeyVerified = value == "true"
@@ -109,15 +109,16 @@ func (c *Controller) _validateRequest(userID int64, key, value string, byAdmin b
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
}
flag := ente.FlagKey(key)
if err := flag.IsValidValue(value); err != nil {
return stacktrace.Propagate(err, "")
}
if !flag.UserEditable() && !byAdmin {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable")
}
if byAdmin && !flag.IsAdminEditable() {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable")
}
if flag.IsBoolType() && value != "true" && value != "false" {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
}
if flag.NeedSubscription() {
return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
}