[server] Support for configuring custom-domains (#6827)
## Description ## Tests - Verified server db migrations is on 103 - Verified that duplicate custom domain results in error - basic sanity testing for custom domain validation.
This commit is contained in:
@@ -98,6 +98,7 @@ func main() {
|
||||
}
|
||||
|
||||
viper.SetDefault("apps.public-albums", "https://albums.ente.io")
|
||||
viper.SetDefault("apps.custom-domain.cname", "my.ente.io")
|
||||
viper.SetDefault("apps.public-locker", "https://locker.ente.io")
|
||||
viper.SetDefault("apps.accounts", "https://accounts.ente.io")
|
||||
viper.SetDefault("apps.cast", "https://cast.ente.io")
|
||||
@@ -213,7 +214,6 @@ func main() {
|
||||
commonBillController := commonbilling.NewController(emailNotificationCtrl, storagBonusRepo, userRepo, usageRepo, billingRepo)
|
||||
appStoreController := controller.NewAppStoreController(defaultPlan,
|
||||
billingRepo, fileRepo, userRepo, commonBillController)
|
||||
remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository}
|
||||
playStoreController := controller.NewPlayStoreController(defaultPlan,
|
||||
billingRepo, fileRepo, userRepo, storagBonusRepo, commonBillController)
|
||||
stripeController := controller.NewStripeController(plans, stripeClients,
|
||||
@@ -222,6 +222,8 @@ func main() {
|
||||
appStoreController, playStoreController, stripeController,
|
||||
discordController, emailNotificationCtrl,
|
||||
billingRepo, userRepo, usageRepo, storagBonusRepo, commonBillController)
|
||||
remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository, BillingCtrl: billingController}
|
||||
|
||||
pushController := controller.NewPushController(pushRepo, taskLockingRepo, hostName)
|
||||
mailingListsController := controller.NewMailingListsController()
|
||||
|
||||
@@ -376,6 +378,7 @@ func main() {
|
||||
Cache: accessTokenCache,
|
||||
BillingCtrl: billingController,
|
||||
DiscordController: discordController,
|
||||
RemoteStoreRepo: remoteStoreRepository,
|
||||
}
|
||||
fileLinkMiddleware := &middleware.FileLinkMiddleware{
|
||||
FileLinkRepo: fileLinkRepo,
|
||||
@@ -788,8 +791,10 @@ func main() {
|
||||
remoteStoreHandler := &api.RemoteStoreHandler{Controller: remoteStoreController}
|
||||
|
||||
privateAPI.POST("/remote-store/update", remoteStoreHandler.InsertOrUpdate)
|
||||
privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey)
|
||||
privateAPI.GET("/remote-store", remoteStoreHandler.GetKey)
|
||||
privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags)
|
||||
publicAPI.GET("/custom-domain", remoteStoreHandler.CheckDomain)
|
||||
|
||||
pushHandler := &api.PushHandler{PushController: pushController}
|
||||
privateAPI.POST("/push/token", pushHandler.AddToken)
|
||||
|
||||
@@ -95,6 +95,9 @@ apps:
|
||||
accounts:
|
||||
# Default is https://family.ente.io
|
||||
family:
|
||||
custom-domain:
|
||||
# Default is my.ente.io
|
||||
cname:
|
||||
|
||||
# Database connection parameters
|
||||
db:
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package ente
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GetValueRequest struct {
|
||||
Key string `form:"key" binding:"required"`
|
||||
DefaultValue *string `form:"defaultValue"`
|
||||
@@ -10,28 +17,30 @@ type GetValueResponse struct {
|
||||
}
|
||||
|
||||
type UpdateKeyValueRequest struct {
|
||||
Key string `json:"key" binding:"required"`
|
||||
Value string `json:"value" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Value *string `json:"value" binding:"required"`
|
||||
}
|
||||
|
||||
type AdminUpdateKeyValueRequest struct {
|
||||
UserID int64 `json:"userID" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Value string `json:"value" binding:"required"`
|
||||
UserID int64 `json:"userID" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Value *string `json:"value" binding:"required"`
|
||||
}
|
||||
|
||||
type FeatureFlagResponse struct {
|
||||
EnableStripe bool `json:"enableStripe"`
|
||||
// If true, the mobile client will stop using CF worker to download files
|
||||
DisableCFWorker bool `json:"disableCFWorker"`
|
||||
MapEnabled bool `json:"mapEnabled"`
|
||||
FaceSearchEnabled bool `json:"faceSearchEnabled"`
|
||||
PassKeyEnabled bool `json:"passKeyEnabled"`
|
||||
RecoveryKeyVerified bool `json:"recoveryKeyVerified"`
|
||||
InternalUser bool `json:"internalUser"`
|
||||
BetaUser bool `json:"betaUser"`
|
||||
EnableMobMultiPart bool `json:"enableMobMultiPart"`
|
||||
CastUrl string `json:"castUrl"`
|
||||
DisableCFWorker bool `json:"disableCFWorker"`
|
||||
MapEnabled bool `json:"mapEnabled"`
|
||||
FaceSearchEnabled bool `json:"faceSearchEnabled"`
|
||||
PassKeyEnabled bool `json:"passKeyEnabled"`
|
||||
RecoveryKeyVerified bool `json:"recoveryKeyVerified"`
|
||||
InternalUser bool `json:"internalUser"`
|
||||
BetaUser bool `json:"betaUser"`
|
||||
EnableMobMultiPart bool `json:"enableMobMultiPart"`
|
||||
CastUrl string `json:"castUrl"`
|
||||
CustomDomain *string `json:"customDomain,omitempty"`
|
||||
CustomDomainCNAME string `json:"customDomainCNAME,omitempty"`
|
||||
}
|
||||
|
||||
type FlagKey string
|
||||
@@ -43,8 +52,24 @@ const (
|
||||
PassKeyEnabled FlagKey = "passKeyEnabled"
|
||||
IsInternalUser FlagKey = "internalUser"
|
||||
IsBetaUser FlagKey = "betaUser"
|
||||
CustomDomain FlagKey = "customDomain"
|
||||
)
|
||||
|
||||
var validFlagKeys = map[FlagKey]struct{}{
|
||||
RecoveryKeyVerified: {},
|
||||
MapEnabled: {},
|
||||
FaceSearchEnabled: {},
|
||||
PassKeyEnabled: {},
|
||||
IsInternalUser: {},
|
||||
IsBetaUser: {},
|
||||
CustomDomain: {},
|
||||
}
|
||||
|
||||
func IsValidFlagKey(key string) bool {
|
||||
_, exists := validFlagKeys[FlagKey(key)]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (k FlagKey) String() string {
|
||||
return string(k)
|
||||
}
|
||||
@@ -52,13 +77,21 @@ func (k FlagKey) String() string {
|
||||
// UserEditable returns true if the key is user editable
|
||||
func (k FlagKey) UserEditable() bool {
|
||||
switch k {
|
||||
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled:
|
||||
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, CustomDomain:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (k FlagKey) NeedSubscription() bool {
|
||||
return k == CustomDomain
|
||||
}
|
||||
|
||||
func (k FlagKey) CanRemove() bool {
|
||||
return k == CustomDomain
|
||||
}
|
||||
|
||||
func (k FlagKey) IsAdminEditable() bool {
|
||||
switch k {
|
||||
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled:
|
||||
@@ -74,7 +107,40 @@ func (k FlagKey) IsBoolType() bool {
|
||||
switch k {
|
||||
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser:
|
||||
return true
|
||||
default:
|
||||
case CustomDomain:
|
||||
return false
|
||||
default:
|
||||
return false // Explicitly handle unexpected cases
|
||||
}
|
||||
}
|
||||
|
||||
func (k FlagKey) IsValidValue(value string) error {
|
||||
if k.IsBoolType() && value != "true" && value != "false" {
|
||||
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
|
||||
}
|
||||
if k == CustomDomain && value != "" {
|
||||
if err := isValidDomainWithoutScheme(value); err != nil {
|
||||
return stacktrace.Propagate(err, "invalid custom domain")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
|
||||
|
||||
func isValidDomainWithoutScheme(input string) error {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed != input {
|
||||
return NewBadRequestWithMessage("domain contains leading or trailing spaces")
|
||||
}
|
||||
if trimmed == "" {
|
||||
return NewBadRequestWithMessage("domain is empty")
|
||||
}
|
||||
if strings.Contains(trimmed, "://") {
|
||||
return NewBadRequestWithMessage("domain should not contain scheme (e.g., http:// or https://)")
|
||||
}
|
||||
if !domainRegex.MatchString(trimmed) {
|
||||
return NewBadRequestWithMessage(fmt.Sprintf("invalid domain format: %s", trimmed))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
92
server/ente/remotestore_test.go
Normal file
92
server/ente/remotestore_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package ente
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsValidDomainWithoutScheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
// ✅ Valid cases
|
||||
{"simple domain", "google.com", false},
|
||||
{"multi-level domain", "sub.example.co.in", false},
|
||||
{"numeric in label", "a1b2c3.com", false},
|
||||
{"long but valid label", "my-very-long-subdomain-name.example.com", false},
|
||||
|
||||
// ❌ Leading/trailing spaces
|
||||
{"leading space", " google.com", true},
|
||||
{"trailing space", "google.com ", true},
|
||||
{"both spaces", " google.com ", true},
|
||||
|
||||
// ❌ Empty or whitespace
|
||||
{"empty string", "", true},
|
||||
{"only spaces", " ", true},
|
||||
|
||||
// ❌ Scheme included
|
||||
{"http scheme", "http://google.com", true},
|
||||
{"https scheme", "https://example.com", true},
|
||||
{"ftp scheme", "ftp://example.com", true},
|
||||
|
||||
// ❌ Invalid characters
|
||||
{"underscore in label", "my_domain.com", true},
|
||||
{"invalid symbol", "exa$mple.com", true},
|
||||
{"space inside", "exa mple.com", true},
|
||||
|
||||
// ❌ Wrong format
|
||||
{"missing dot", "localhost", true},
|
||||
{"single label TLD", "com", true},
|
||||
{"ends with dot", "example.com.", true},
|
||||
{"ends with dash", "example-.com", true},
|
||||
{"starts with dash", "-example.com", true},
|
||||
|
||||
// ❌ Consecutive dots
|
||||
{"double dots", "example..com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := isValidDomainWithoutScheme(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("isValidDomainWithoutScheme(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlagKey_IsValidValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key FlagKey
|
||||
value string
|
||||
wantErr bool
|
||||
}{
|
||||
// ✅ Valid boolean flag values
|
||||
{"valid true for bool key", MapEnabled, "true", false},
|
||||
{"valid false for bool key", FaceSearchEnabled, "false", false},
|
||||
|
||||
// ❌ Invalid boolean flag values
|
||||
{"invalid value for bool key", PassKeyEnabled, "yes", true},
|
||||
{"empty value for bool key", IsInternalUser, "", true},
|
||||
|
||||
// ✅ Valid custom domain values
|
||||
{"valid custom domain", CustomDomain, "example.com", false},
|
||||
{"valid subdomain", CustomDomain, "sub.example.com", false},
|
||||
|
||||
// ❌ Invalid custom domain values
|
||||
{"empty custom domain", CustomDomain, "", false}, // Allowed as empty
|
||||
{"custom domain with scheme", CustomDomain, "http://example.com", true},
|
||||
{"custom domain with invalid format", CustomDomain, "exa$mple.com", true},
|
||||
{"custom domain with leading space", CustomDomain, " example.com", true},
|
||||
{"custom domain with trailing space", CustomDomain, "example.com ", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.key.IsValidValue(tt.value)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FlagKey(%q).IsValidValue(%q) error = %v, wantErr %v", tt.key, tt.value, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1
server/migrations/104_rs_custom_domain.down.sql
Normal file
1
server/migrations/104_rs_custom_domain.down.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS remote_store_custom_domain_unique_idx;
|
||||
3
server/migrations/104_rs_custom_domain.up.sql
Normal file
3
server/migrations/104_rs_custom_domain.up.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS remote_store_custom_domain_unique_idx
|
||||
ON remote_store (key_value)
|
||||
WHERE key_name = 'customDomain';
|
||||
@@ -3,13 +3,14 @@ package api
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/controller/emergency"
|
||||
"github.com/ente-io/museum/pkg/controller/remotestore"
|
||||
"github.com/ente-io/museum/pkg/repo/authenticator"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ente-io/museum/pkg/controller/emergency"
|
||||
"github.com/ente-io/museum/pkg/controller/remotestore"
|
||||
"github.com/ente-io/museum/pkg/repo/authenticator"
|
||||
|
||||
"github.com/ente-io/museum/pkg/controller/family"
|
||||
|
||||
bonusEntity "github.com/ente-io/museum/ente/storagebonus"
|
||||
@@ -377,7 +378,7 @@ func (h *AdminHandler) UpdateFeatureFlag(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
go h.DiscordController.NotifyAdminAction(
|
||||
fmt.Sprintf("Admin (%d) updating flag:%s to val:%s for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID))
|
||||
fmt.Sprintf("Admin (%d) updating flag:%s to val:%v for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID))
|
||||
|
||||
logger := logrus.WithFields(logrus.Fields{
|
||||
"user_id": request.UserID,
|
||||
|
||||
@@ -33,6 +33,20 @@ func (h *RemoteStoreHandler) InsertOrUpdate(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *RemoteStoreHandler) RemoveKey(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
if key == "" {
|
||||
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("key is missing"), ""))
|
||||
return
|
||||
}
|
||||
err := h.Controller.RemoveKey(c, key)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, "failed to update key's value"))
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// GetKey handler for fetching a value for particular key
|
||||
func (h *RemoteStoreHandler) GetKey(c *gin.Context) {
|
||||
var request ente.GetValueRequest
|
||||
@@ -59,3 +73,18 @@ func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// CheckDomain returns 200 ok if the custom domain is claimed by any ente user
|
||||
func (h *RemoteStoreHandler) CheckDomain(c *gin.Context) {
|
||||
domain := c.Query("domain")
|
||||
if domain == "" {
|
||||
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("domain is missing"), ""))
|
||||
return
|
||||
}
|
||||
_, err := h.Controller.DomainOwner(c, domain)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags"))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
@@ -15,23 +16,39 @@ import (
|
||||
|
||||
// Controller is interface for exposing business logic related to for remote store
|
||||
type Controller struct {
|
||||
Repo *remotestore.Repository
|
||||
Repo *remotestore.Repository
|
||||
BillingCtrl *controller.BillingController
|
||||
}
|
||||
|
||||
// InsertOrUpdate the key's value
|
||||
func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValueRequest) error {
|
||||
if err := _validateRequest(request.Key, request.Value, false); err != nil {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if *request.Value == "" && ente.FlagKey(request.Key).CanRemove() {
|
||||
return c.Repo.RemoveKey(ctx, userID, request.Key)
|
||||
}
|
||||
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, *request.Value)
|
||||
}
|
||||
|
||||
// RemoveKey removes the key from remote store
|
||||
func (c *Controller) RemoveKey(ctx *gin.Context, key string) error {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value)
|
||||
if valid := ente.IsValidFlagKey(key); !valid {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
|
||||
}
|
||||
if !ente.FlagKey(key).CanRemove() {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not removable", key)), "key not removable")
|
||||
}
|
||||
return c.Repo.RemoveKey(ctx, userID, key)
|
||||
}
|
||||
|
||||
func (c *Controller) AdminInsertOrUpdate(ctx *gin.Context, request ente.AdminUpdateKeyValueRequest) error {
|
||||
if err := _validateRequest(request.Key, request.Value, true); err != nil {
|
||||
if err := c._validateRequest(request.UserID, request.Key, request.Value, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, request.Value)
|
||||
return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, *request.Value)
|
||||
}
|
||||
|
||||
func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetValueResponse, error) {
|
||||
@@ -61,12 +78,10 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
|
||||
// except internal user.rt
|
||||
EnableMobMultiPart: true,
|
||||
CastUrl: viper.GetString("apps.cast"),
|
||||
CustomDomainCNAME: viper.GetString("apps.custom-domain.cname"),
|
||||
}
|
||||
for key, value := range values {
|
||||
flag := ente.FlagKey(key)
|
||||
if !flag.IsBoolType() {
|
||||
continue
|
||||
}
|
||||
switch flag {
|
||||
case ente.RecoveryKeyVerified:
|
||||
response.RecoveryKeyVerified = value == "true"
|
||||
@@ -80,21 +95,40 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
|
||||
response.InternalUser = value == "true"
|
||||
case ente.IsBetaUser:
|
||||
response.BetaUser = value == "true"
|
||||
case ente.CustomDomain:
|
||||
if value != "" {
|
||||
response.CustomDomain = &value
|
||||
}
|
||||
}
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func _validateRequest(key, value string, byAdmin bool) error {
|
||||
func (c *Controller) DomainOwner(ctx *gin.Context, domain string) (*int64, error) {
|
||||
return c.Repo.DomainOwner(ctx, domain)
|
||||
}
|
||||
|
||||
func (c *Controller) _validateRequest(userID int64, key string, valuePtr *string, byAdmin bool) error {
|
||||
if !ente.IsValidFlagKey(key) {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
|
||||
}
|
||||
if valuePtr == nil {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage("value is missing"), "value is nil")
|
||||
}
|
||||
value := *valuePtr
|
||||
flag := ente.FlagKey(key)
|
||||
if err := flag.IsValidValue(value); err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
if !flag.UserEditable() && !byAdmin {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable")
|
||||
}
|
||||
if byAdmin && !flag.IsAdminEditable() {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable")
|
||||
}
|
||||
if flag.IsBoolType() && value != "true" && value != "false" {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
|
||||
|
||||
if flag.NeedSubscription() {
|
||||
return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,12 @@ import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/ente-io/museum/pkg/repo/remotestore"
|
||||
"github.com/gin-contrib/requestid"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
public2 "github.com/ente-io/museum/pkg/controller/public"
|
||||
"github.com/ente-io/museum/pkg/repo/public"
|
||||
@@ -35,6 +41,7 @@ type CollectionLinkMiddleware struct {
|
||||
Cache *cache.Cache
|
||||
BillingCtrl *controller.BillingController
|
||||
DiscordController *discord.DiscordController
|
||||
RemoteStoreRepo *remotestore.Repository
|
||||
}
|
||||
|
||||
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
|
||||
@@ -65,7 +72,7 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context
|
||||
return
|
||||
}
|
||||
// validate if user still has active paid subscription
|
||||
if err = m.validateOwnersSubscription(publicCollectionSummary.CollectionID); err != nil {
|
||||
if err = m.validateOwnersSubscription(c, publicCollectionSummary.CollectionID); err != nil {
|
||||
logrus.WithError(err).Warn("failed to verify active paid subscription")
|
||||
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
|
||||
return
|
||||
@@ -115,12 +122,17 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error {
|
||||
func (m *CollectionLinkMiddleware) validateOwnersSubscription(c *gin.Context, cID int64) error {
|
||||
userID, err := m.CollectionRepo.GetOwnerID(cID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
|
||||
err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "failed to validate owners subscription")
|
||||
}
|
||||
m.validateOrigin(c, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context,
|
||||
@@ -177,6 +189,42 @@ func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath stri
|
||||
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash)
|
||||
}
|
||||
|
||||
func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
if origin == "" || origin == viper.GetString("apps.public-albums") {
|
||||
return
|
||||
}
|
||||
reqId := requestid.Get(c)
|
||||
logger := logrus.WithFields(logrus.Fields{
|
||||
"ownerID": ownerID,
|
||||
"req_id": reqId,
|
||||
"origin": origin,
|
||||
})
|
||||
alertMessage := fmt.Sprintf("custom domain check failed %s", reqId)
|
||||
domain, err := m.RemoteStoreRepo.GetDomain(c, ownerID)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to fetch custom domain for owner")
|
||||
m.DiscordController.NotifyPotentialAbuse(alertMessage)
|
||||
return
|
||||
}
|
||||
if domain == nil || *domain == "" {
|
||||
logger.Warn("custom domain is nil or empty")
|
||||
m.DiscordController.NotifyPotentialAbuse(alertMessage)
|
||||
return
|
||||
}
|
||||
parse, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to parse origin URL")
|
||||
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - failed to parse origin URL")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(parse.Host), strings.ToLower(*domain)) {
|
||||
logger.Warnf("custom domain check failed for owner %d, origin %s, domain %s", ownerID, origin, *domain)
|
||||
m.DiscordController.NotifyPotentialAbuse(alertMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func computeHashKeyForList(list []string, delim string) string {
|
||||
var buffer bytes.Buffer
|
||||
for i := range list {
|
||||
|
||||
@@ -133,7 +133,8 @@ func (r *RateLimitMiddleware) APIRateLimitForUserMiddleware(urlSanitizer func(_
|
||||
// getLimiter, based on reqPath & reqMethod, return instance of limiter.Limiter which needs to
|
||||
// be applied for a request. It returns nil if the request is not rate limited
|
||||
func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limiter.Limiter {
|
||||
if reqPath == "/users/public-key" {
|
||||
if reqPath == "/users/public-key" ||
|
||||
reqPath == "/custom-domain" {
|
||||
return r.limit200ReqPerMin
|
||||
}
|
||||
if reqPath == "/users/ott" ||
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const BaseShareURL = "https://albums.ente.io/?t=%s"
|
||||
|
||||
// CollectionLinkRepo defines the methods for inserting, updating and
|
||||
// retrieving entities related to public collections
|
||||
type CollectionLinkRepo struct {
|
||||
|
||||
@@ -3,6 +3,9 @@ package remotestore
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/ente-io/stacktrace"
|
||||
)
|
||||
@@ -21,9 +24,66 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key strin
|
||||
key, // $2 key_name
|
||||
value, // $3 key_value
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
// Check for unique violation (PostgreSQL error code 23505)
|
||||
var pgErr *pq.Error
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
if pgErr.Constraint == "remote_store_custom_domain_unique_idx" {
|
||||
return ente.NewConflictError("custom domain already exists for another user")
|
||||
}
|
||||
}
|
||||
return stacktrace.Propagate(err, "failed to insert/update")
|
||||
}
|
||||
return stacktrace.Propagate(err, "failed to insert/update")
|
||||
}
|
||||
|
||||
func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) error {
|
||||
_, err := r.DB.ExecContext(ctx, `DELETE FROM remote_store
|
||||
WHERE user_id = $1 AND key_name = $2`,
|
||||
userID, // $1
|
||||
key, // $2
|
||||
)
|
||||
return stacktrace.Propagate(err, "failed to remove key")
|
||||
}
|
||||
|
||||
func (r *Repository) DomainOwner(ctx context.Context, domain string) (*int64, error) {
|
||||
// Check if the domain is already taken by another user
|
||||
rows := r.DB.QueryRowContext(ctx, `SELECT user_id FROM remote_store
|
||||
WHERE key_name = $1 AND key_value = $2`,
|
||||
ente.CustomDomain, // $1
|
||||
domain, // $2
|
||||
)
|
||||
var userID int64
|
||||
err := rows.Scan(&userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "")
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to fetch domain owner")
|
||||
}
|
||||
return &userID, nil
|
||||
}
|
||||
|
||||
func (r *Repository) GetDomain(ctx context.Context, userID int64) (*string, error) {
|
||||
// Fetch the custom domain for the user
|
||||
rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store
|
||||
WHERE user_id = $1 AND key_name = $2`,
|
||||
userID, // $1
|
||||
ente.CustomDomain, // $2
|
||||
)
|
||||
var domain string
|
||||
err := rows.Scan(&domain)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, stacktrace.Propagate(err, "failed to fetch custom domain")
|
||||
}
|
||||
return &domain, nil
|
||||
|
||||
}
|
||||
|
||||
// GetValue fetches and return the value for given user_id and key
|
||||
func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) {
|
||||
rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store
|
||||
|
||||
Reference in New Issue
Block a user