diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 597826ca92..375bef2798 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -98,7 +98,7 @@ func main() { } viper.SetDefault("apps.public-albums", "https://albums.ente.io") - viper.SetDefault("apps.custom-domain.cname", "https://my.ente.io") + viper.SetDefault("apps.custom-domain.cname", "my.ente.io") viper.SetDefault("apps.public-locker", "https://locker.ente.io") viper.SetDefault("apps.accounts", "https://accounts.ente.io") viper.SetDefault("apps.cast", "https://cast.ente.io") diff --git a/server/ente/remotestore.go b/server/ente/remotestore.go index 02bcc52c1d..3a9a965bf7 100644 --- a/server/ente/remotestore.go +++ b/server/ente/remotestore.go @@ -3,7 +3,7 @@ package ente import ( "fmt" "github.com/ente-io/stacktrace" - "net/url" + "regexp" "strings" ) @@ -119,22 +119,28 @@ func (k FlagKey) IsValidValue(value string) error { return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed") } if k == CustomDomain && value != "" { - if !isValidCustomDomainURL(value) { - return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("invalid domain fmt: %s", value)), "url with https://. Also, tt should not end with trailing dash.") + if err := isValidDomainWithoutScheme(value); err != nil { + return stacktrace.Propagate(err, "invalid custom domain") } } return nil } -func isValidCustomDomainURL(input string) bool { - if !strings.HasPrefix(input, "https://") || strings.HasSuffix(input, "/") { - return false - } +var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`) - u, err := url.Parse(input) - if err != nil || u.Scheme != "https" || u.Host == "" { - return false +func isValidDomainWithoutScheme(input string) error { + trimmed := strings.TrimSpace(input) + if trimmed != input { + return NewBadRequestWithMessage("domain contains leading or trailing spaces") } - - return true + if trimmed == "" { + return NewBadRequestWithMessage("domain is empty") + } + if strings.Contains(trimmed, "://") { + return NewBadRequestWithMessage("domain should not contain scheme (e.g., http:// or https://)") + } + if !domainRegex.MatchString(trimmed) { + return NewBadRequestWithMessage(fmt.Sprintf("invalid domain format: %s", trimmed)) + } + return nil } diff --git a/server/ente/remotestore_test.go b/server/ente/remotestore_test.go new file mode 100644 index 0000000000..1056e5ca6b --- /dev/null +++ b/server/ente/remotestore_test.go @@ -0,0 +1,92 @@ +package ente + +import "testing" + +func TestIsValidDomainWithoutScheme(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + // ✅ Valid cases + {"simple domain", "google.com", false}, + {"multi-level domain", "sub.example.co.in", false}, + {"numeric in label", "a1b2c3.com", false}, + {"long but valid label", "my-very-long-subdomain-name.example.com", false}, + + // ❌ Leading/trailing spaces + {"leading space", " google.com", true}, + {"trailing space", "google.com ", true}, + {"both spaces", " google.com ", true}, + + // ❌ Empty or whitespace + {"empty string", "", true}, + {"only spaces", " ", true}, + + // ❌ Scheme included + {"http scheme", "http://google.com", true}, + {"https scheme", "https://example.com", true}, + {"ftp scheme", "ftp://example.com", true}, + + // ❌ Invalid characters + {"underscore in label", "my_domain.com", true}, + {"invalid symbol", "exa$mple.com", true}, + {"space inside", "exa mple.com", true}, + + // ❌ Wrong format + {"missing dot", "localhost", true}, + {"single label TLD", "com", true}, + {"ends with dot", "example.com.", true}, + {"ends with dash", "example-.com", true}, + {"starts with dash", "-example.com", true}, + + // ❌ Consecutive dots + {"double dots", "example..com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isValidDomainWithoutScheme(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("isValidDomainWithoutScheme(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestFlagKey_IsValidValue(t *testing.T) { + tests := []struct { + name string + key FlagKey + value string + wantErr bool + }{ + // ✅ Valid boolean flag values + {"valid true for bool key", MapEnabled, "true", false}, + {"valid false for bool key", FaceSearchEnabled, "false", false}, + + // ❌ Invalid boolean flag values + {"invalid value for bool key", PassKeyEnabled, "yes", true}, + {"empty value for bool key", IsInternalUser, "", true}, + + // ✅ Valid custom domain values + {"valid custom domain", CustomDomain, "example.com", false}, + {"valid subdomain", CustomDomain, "sub.example.com", false}, + + // ❌ Invalid custom domain values + {"empty custom domain", CustomDomain, "", false}, // Allowed as empty + {"custom domain with scheme", CustomDomain, "http://example.com", true}, + {"custom domain with invalid format", CustomDomain, "exa$mple.com", true}, + {"custom domain with leading space", CustomDomain, " example.com", true}, + {"custom domain with trailing space", CustomDomain, "example.com ", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.key.IsValidValue(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("FlagKey(%q).IsValidValue(%q) error = %v, wantErr %v", tt.key, tt.value, err, tt.wantErr) + } + }) + } +}