diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 6c61e50273..abdd9e58ce 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -98,6 +98,7 @@ func main() { } viper.SetDefault("apps.public-albums", "https://albums.ente.io") + viper.SetDefault("apps.custom-domain.cname", "my.ente.io") viper.SetDefault("apps.public-locker", "https://locker.ente.io") viper.SetDefault("apps.accounts", "https://accounts.ente.io") viper.SetDefault("apps.cast", "https://cast.ente.io") @@ -213,7 +214,6 @@ func main() { commonBillController := commonbilling.NewController(emailNotificationCtrl, storagBonusRepo, userRepo, usageRepo, billingRepo) appStoreController := controller.NewAppStoreController(defaultPlan, billingRepo, fileRepo, userRepo, commonBillController) - remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository} playStoreController := controller.NewPlayStoreController(defaultPlan, billingRepo, fileRepo, userRepo, storagBonusRepo, commonBillController) stripeController := controller.NewStripeController(plans, stripeClients, @@ -222,6 +222,8 @@ func main() { appStoreController, playStoreController, stripeController, discordController, emailNotificationCtrl, billingRepo, userRepo, usageRepo, storagBonusRepo, commonBillController) + remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository, BillingCtrl: billingController} + pushController := controller.NewPushController(pushRepo, taskLockingRepo, hostName) mailingListsController := controller.NewMailingListsController() @@ -376,6 +378,7 @@ func main() { Cache: accessTokenCache, BillingCtrl: billingController, DiscordController: discordController, + RemoteStoreRepo: remoteStoreRepository, } fileLinkMiddleware := &middleware.FileLinkMiddleware{ FileLinkRepo: fileLinkRepo, @@ -788,8 +791,10 @@ func main() { remoteStoreHandler := &api.RemoteStoreHandler{Controller: remoteStoreController} privateAPI.POST("/remote-store/update", remoteStoreHandler.InsertOrUpdate) + privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey) privateAPI.GET("/remote-store", remoteStoreHandler.GetKey) privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags) + publicAPI.GET("/custom-domain", remoteStoreHandler.CheckDomain) pushHandler := &api.PushHandler{PushController: pushController} privateAPI.POST("/push/token", pushHandler.AddToken) diff --git a/server/configurations/local.yaml b/server/configurations/local.yaml index a16f560e43..70d7589280 100644 --- a/server/configurations/local.yaml +++ b/server/configurations/local.yaml @@ -95,6 +95,9 @@ apps: accounts: # Default is https://family.ente.io family: + custom-domain: + # Default is my.ente.io + cname: # Database connection parameters db: diff --git a/server/ente/remotestore.go b/server/ente/remotestore.go index 8546fd3cf0..ea96d8de1c 100644 --- a/server/ente/remotestore.go +++ b/server/ente/remotestore.go @@ -1,5 +1,12 @@ package ente +import ( + "fmt" + "github.com/ente-io/stacktrace" + "regexp" + "strings" +) + type GetValueRequest struct { Key string `form:"key" binding:"required"` DefaultValue *string `form:"defaultValue"` @@ -10,28 +17,30 @@ type GetValueResponse struct { } type UpdateKeyValueRequest struct { - Key string `json:"key" binding:"required"` - Value string `json:"value" binding:"required"` + Key string `json:"key" binding:"required"` + Value *string `json:"value" binding:"required"` } type AdminUpdateKeyValueRequest struct { - UserID int64 `json:"userID" binding:"required"` - Key string `json:"key" binding:"required"` - Value string `json:"value" binding:"required"` + UserID int64 `json:"userID" binding:"required"` + Key string `json:"key" binding:"required"` + Value *string `json:"value" binding:"required"` } type FeatureFlagResponse struct { EnableStripe bool `json:"enableStripe"` // If true, the mobile client will stop using CF worker to download files - DisableCFWorker bool `json:"disableCFWorker"` - MapEnabled bool `json:"mapEnabled"` - FaceSearchEnabled bool `json:"faceSearchEnabled"` - PassKeyEnabled bool `json:"passKeyEnabled"` - RecoveryKeyVerified bool `json:"recoveryKeyVerified"` - InternalUser bool `json:"internalUser"` - BetaUser bool `json:"betaUser"` - EnableMobMultiPart bool `json:"enableMobMultiPart"` - CastUrl string `json:"castUrl"` + DisableCFWorker bool `json:"disableCFWorker"` + MapEnabled bool `json:"mapEnabled"` + FaceSearchEnabled bool `json:"faceSearchEnabled"` + PassKeyEnabled bool `json:"passKeyEnabled"` + RecoveryKeyVerified bool `json:"recoveryKeyVerified"` + InternalUser bool `json:"internalUser"` + BetaUser bool `json:"betaUser"` + EnableMobMultiPart bool `json:"enableMobMultiPart"` + CastUrl string `json:"castUrl"` + CustomDomain *string `json:"customDomain,omitempty"` + CustomDomainCNAME string `json:"customDomainCNAME,omitempty"` } type FlagKey string @@ -43,8 +52,24 @@ const ( PassKeyEnabled FlagKey = "passKeyEnabled" IsInternalUser FlagKey = "internalUser" IsBetaUser FlagKey = "betaUser" + CustomDomain FlagKey = "customDomain" ) +var validFlagKeys = map[FlagKey]struct{}{ + RecoveryKeyVerified: {}, + MapEnabled: {}, + FaceSearchEnabled: {}, + PassKeyEnabled: {}, + IsInternalUser: {}, + IsBetaUser: {}, + CustomDomain: {}, +} + +func IsValidFlagKey(key string) bool { + _, exists := validFlagKeys[FlagKey(key)] + return exists +} + func (k FlagKey) String() string { return string(k) } @@ -52,13 +77,21 @@ func (k FlagKey) String() string { // UserEditable returns true if the key is user editable func (k FlagKey) UserEditable() bool { switch k { - case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled: + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, CustomDomain: return true default: return false } } +func (k FlagKey) NeedSubscription() bool { + return k == CustomDomain +} + +func (k FlagKey) CanRemove() bool { + return k == CustomDomain +} + func (k FlagKey) IsAdminEditable() bool { switch k { case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled: @@ -74,7 +107,40 @@ func (k FlagKey) IsBoolType() bool { switch k { case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser: return true - default: + case CustomDomain: return false + default: + return false // Explicitly handle unexpected cases } } + +func (k FlagKey) IsValidValue(value string) error { + if k.IsBoolType() && value != "true" && value != "false" { + return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed") + } + if k == CustomDomain && value != "" { + if err := isValidDomainWithoutScheme(value); err != nil { + return stacktrace.Propagate(err, "invalid custom domain") + } + } + return nil +} + +var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`) + +func isValidDomainWithoutScheme(input string) error { + trimmed := strings.TrimSpace(input) + if trimmed != input { + return NewBadRequestWithMessage("domain contains leading or trailing spaces") + } + if trimmed == "" { + return NewBadRequestWithMessage("domain is empty") + } + if strings.Contains(trimmed, "://") { + return NewBadRequestWithMessage("domain should not contain scheme (e.g., http:// or https://)") + } + if !domainRegex.MatchString(trimmed) { + return NewBadRequestWithMessage(fmt.Sprintf("invalid domain format: %s", trimmed)) + } + return nil +} diff --git a/server/ente/remotestore_test.go b/server/ente/remotestore_test.go new file mode 100644 index 0000000000..1056e5ca6b --- /dev/null +++ b/server/ente/remotestore_test.go @@ -0,0 +1,92 @@ +package ente + +import "testing" + +func TestIsValidDomainWithoutScheme(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + // ✅ Valid cases + {"simple domain", "google.com", false}, + {"multi-level domain", "sub.example.co.in", false}, + {"numeric in label", "a1b2c3.com", false}, + {"long but valid label", "my-very-long-subdomain-name.example.com", false}, + + // ❌ Leading/trailing spaces + {"leading space", " google.com", true}, + {"trailing space", "google.com ", true}, + {"both spaces", " google.com ", true}, + + // ❌ Empty or whitespace + {"empty string", "", true}, + {"only spaces", " ", true}, + + // ❌ Scheme included + {"http scheme", "http://google.com", true}, + {"https scheme", "https://example.com", true}, + {"ftp scheme", "ftp://example.com", true}, + + // ❌ Invalid characters + {"underscore in label", "my_domain.com", true}, + {"invalid symbol", "exa$mple.com", true}, + {"space inside", "exa mple.com", true}, + + // ❌ Wrong format + {"missing dot", "localhost", true}, + {"single label TLD", "com", true}, + {"ends with dot", "example.com.", true}, + {"ends with dash", "example-.com", true}, + {"starts with dash", "-example.com", true}, + + // ❌ Consecutive dots + {"double dots", "example..com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isValidDomainWithoutScheme(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("isValidDomainWithoutScheme(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestFlagKey_IsValidValue(t *testing.T) { + tests := []struct { + name string + key FlagKey + value string + wantErr bool + }{ + // ✅ Valid boolean flag values + {"valid true for bool key", MapEnabled, "true", false}, + {"valid false for bool key", FaceSearchEnabled, "false", false}, + + // ❌ Invalid boolean flag values + {"invalid value for bool key", PassKeyEnabled, "yes", true}, + {"empty value for bool key", IsInternalUser, "", true}, + + // ✅ Valid custom domain values + {"valid custom domain", CustomDomain, "example.com", false}, + {"valid subdomain", CustomDomain, "sub.example.com", false}, + + // ❌ Invalid custom domain values + {"empty custom domain", CustomDomain, "", false}, // Allowed as empty + {"custom domain with scheme", CustomDomain, "http://example.com", true}, + {"custom domain with invalid format", CustomDomain, "exa$mple.com", true}, + {"custom domain with leading space", CustomDomain, " example.com", true}, + {"custom domain with trailing space", CustomDomain, "example.com ", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.key.IsValidValue(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("FlagKey(%q).IsValidValue(%q) error = %v, wantErr %v", tt.key, tt.value, err, tt.wantErr) + } + }) + } +} diff --git a/server/migrations/104_rs_custom_domain.down.sql b/server/migrations/104_rs_custom_domain.down.sql new file mode 100644 index 0000000000..de4e692485 --- /dev/null +++ b/server/migrations/104_rs_custom_domain.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS remote_store_custom_domain_unique_idx; diff --git a/server/migrations/104_rs_custom_domain.up.sql b/server/migrations/104_rs_custom_domain.up.sql new file mode 100644 index 0000000000..4ab6b896f2 --- /dev/null +++ b/server/migrations/104_rs_custom_domain.up.sql @@ -0,0 +1,3 @@ +CREATE UNIQUE INDEX IF NOT EXISTS remote_store_custom_domain_unique_idx + ON remote_store (key_value) + WHERE key_name = 'customDomain'; diff --git a/server/pkg/api/admin.go b/server/pkg/api/admin.go index 124d21078f..7ae784ab41 100644 --- a/server/pkg/api/admin.go +++ b/server/pkg/api/admin.go @@ -3,13 +3,14 @@ package api import ( "errors" "fmt" - "github.com/ente-io/museum/pkg/controller/emergency" - "github.com/ente-io/museum/pkg/controller/remotestore" - "github.com/ente-io/museum/pkg/repo/authenticator" "net/http" "strconv" "strings" + "github.com/ente-io/museum/pkg/controller/emergency" + "github.com/ente-io/museum/pkg/controller/remotestore" + "github.com/ente-io/museum/pkg/repo/authenticator" + "github.com/ente-io/museum/pkg/controller/family" bonusEntity "github.com/ente-io/museum/ente/storagebonus" @@ -377,7 +378,7 @@ func (h *AdminHandler) UpdateFeatureFlag(c *gin.Context) { return } go h.DiscordController.NotifyAdminAction( - fmt.Sprintf("Admin (%d) updating flag:%s to val:%s for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID)) + fmt.Sprintf("Admin (%d) updating flag:%s to val:%v for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID)) logger := logrus.WithFields(logrus.Fields{ "user_id": request.UserID, diff --git a/server/pkg/api/remotestore.go b/server/pkg/api/remotestore.go index 9f03554de8..e83883c2dc 100644 --- a/server/pkg/api/remotestore.go +++ b/server/pkg/api/remotestore.go @@ -33,6 +33,20 @@ func (h *RemoteStoreHandler) InsertOrUpdate(c *gin.Context) { c.Status(http.StatusOK) } +func (h *RemoteStoreHandler) RemoveKey(c *gin.Context) { + key := c.Param("key") + if key == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("key is missing"), "")) + return + } + err := h.Controller.RemoveKey(c, key) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to update key's value")) + return + } + c.Status(http.StatusOK) +} + // GetKey handler for fetching a value for particular key func (h *RemoteStoreHandler) GetKey(c *gin.Context) { var request ente.GetValueRequest @@ -59,3 +73,18 @@ func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) { } c.JSON(http.StatusOK, resp) } + +// CheckDomain returns 200 ok if the custom domain is claimed by any ente user +func (h *RemoteStoreHandler) CheckDomain(c *gin.Context) { + domain := c.Query("domain") + if domain == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("domain is missing"), "")) + return + } + _, err := h.Controller.DomainOwner(c, domain) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/server/pkg/controller/remotestore/controller.go b/server/pkg/controller/remotestore/controller.go index f031f65033..eac04534ba 100644 --- a/server/pkg/controller/remotestore/controller.go +++ b/server/pkg/controller/remotestore/controller.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/ente-io/museum/pkg/controller" "github.com/spf13/viper" "github.com/ente-io/museum/ente" @@ -15,23 +16,39 @@ import ( // Controller is interface for exposing business logic related to for remote store type Controller struct { - Repo *remotestore.Repository + Repo *remotestore.Repository + BillingCtrl *controller.BillingController } // InsertOrUpdate the key's value func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValueRequest) error { - if err := _validateRequest(request.Key, request.Value, false); err != nil { + userID := auth.GetUserID(ctx.Request.Header) + if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil { return err } + if *request.Value == "" && ente.FlagKey(request.Key).CanRemove() { + return c.Repo.RemoveKey(ctx, userID, request.Key) + } + return c.Repo.InsertOrUpdate(ctx, userID, request.Key, *request.Value) +} + +// RemoveKey removes the key from remote store +func (c *Controller) RemoveKey(ctx *gin.Context, key string) error { userID := auth.GetUserID(ctx.Request.Header) - return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value) + if valid := ente.IsValidFlagKey(key); !valid { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key") + } + if !ente.FlagKey(key).CanRemove() { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not removable", key)), "key not removable") + } + return c.Repo.RemoveKey(ctx, userID, key) } func (c *Controller) AdminInsertOrUpdate(ctx *gin.Context, request ente.AdminUpdateKeyValueRequest) error { - if err := _validateRequest(request.Key, request.Value, true); err != nil { + if err := c._validateRequest(request.UserID, request.Key, request.Value, true); err != nil { return err } - return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, request.Value) + return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, *request.Value) } func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetValueResponse, error) { @@ -61,12 +78,10 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons // except internal user.rt EnableMobMultiPart: true, CastUrl: viper.GetString("apps.cast"), + CustomDomainCNAME: viper.GetString("apps.custom-domain.cname"), } for key, value := range values { flag := ente.FlagKey(key) - if !flag.IsBoolType() { - continue - } switch flag { case ente.RecoveryKeyVerified: response.RecoveryKeyVerified = value == "true" @@ -80,21 +95,40 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons response.InternalUser = value == "true" case ente.IsBetaUser: response.BetaUser = value == "true" + case ente.CustomDomain: + if value != "" { + response.CustomDomain = &value + } } } return response, nil } -func _validateRequest(key, value string, byAdmin bool) error { +func (c *Controller) DomainOwner(ctx *gin.Context, domain string) (*int64, error) { + return c.Repo.DomainOwner(ctx, domain) +} + +func (c *Controller) _validateRequest(userID int64, key string, valuePtr *string, byAdmin bool) error { + if !ente.IsValidFlagKey(key) { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key") + } + if valuePtr == nil { + return stacktrace.Propagate(ente.NewBadRequestWithMessage("value is missing"), "value is nil") + } + value := *valuePtr flag := ente.FlagKey(key) + if err := flag.IsValidValue(value); err != nil { + return stacktrace.Propagate(err, "") + } if !flag.UserEditable() && !byAdmin { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable") } if byAdmin && !flag.IsAdminEditable() { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable") } - if flag.IsBoolType() && value != "true" && value != "false" { - return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed") + + if flag.NeedSubscription() { + return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) } return nil } diff --git a/server/pkg/middleware/collection_link.go b/server/pkg/middleware/collection_link.go index 9c20362559..a0e59df2ca 100644 --- a/server/pkg/middleware/collection_link.go +++ b/server/pkg/middleware/collection_link.go @@ -6,6 +6,12 @@ import ( "crypto/sha256" "fmt" "net/http" + "net/url" + "strings" + + "github.com/ente-io/museum/pkg/repo/remotestore" + "github.com/gin-contrib/requestid" + "github.com/spf13/viper" public2 "github.com/ente-io/museum/pkg/controller/public" "github.com/ente-io/museum/pkg/repo/public" @@ -35,6 +41,7 @@ type CollectionLinkMiddleware struct { Cache *cache.Cache BillingCtrl *controller.BillingController DiscordController *discord.DiscordController + RemoteStoreRepo *remotestore.Repository } // Authenticate returns a middle ware that extracts the `X-Auth-Access-Token` @@ -65,7 +72,7 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context return } // validate if user still has active paid subscription - if err = m.validateOwnersSubscription(publicCollectionSummary.CollectionID); err != nil { + if err = m.validateOwnersSubscription(c, publicCollectionSummary.CollectionID); err != nil { logrus.WithError(err).Warn("failed to verify active paid subscription") c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"}) return @@ -115,12 +122,17 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context c.Next() } } -func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error { +func (m *CollectionLinkMiddleware) validateOwnersSubscription(c *gin.Context, cID int64) error { userID, err := m.CollectionRepo.GetOwnerID(cID) if err != nil { return stacktrace.Propagate(err, "") } - return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false) + err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false) + if err != nil { + return stacktrace.Propagate(err, "failed to validate owners subscription") + } + m.validateOrigin(c, userID) + return nil } func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context, @@ -177,6 +189,42 @@ func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath stri return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash) } +func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) { + origin := c.Request.Header.Get("Origin") + + if origin == "" || origin == viper.GetString("apps.public-albums") { + return + } + reqId := requestid.Get(c) + logger := logrus.WithFields(logrus.Fields{ + "ownerID": ownerID, + "req_id": reqId, + "origin": origin, + }) + alertMessage := fmt.Sprintf("custom domain check failed %s", reqId) + domain, err := m.RemoteStoreRepo.GetDomain(c, ownerID) + if err != nil { + logger.WithError(err).Error("failed to fetch custom domain for owner") + m.DiscordController.NotifyPotentialAbuse(alertMessage) + return + } + if domain == nil || *domain == "" { + logger.Warn("custom domain is nil or empty") + m.DiscordController.NotifyPotentialAbuse(alertMessage) + return + } + parse, err := url.Parse(origin) + if err != nil { + logger.WithError(err).Error("failed to parse origin URL") + m.DiscordController.NotifyPotentialAbuse(alertMessage + " - failed to parse origin URL") + return + } + if !strings.Contains(strings.ToLower(parse.Host), strings.ToLower(*domain)) { + logger.Warnf("custom domain check failed for owner %d, origin %s, domain %s", ownerID, origin, *domain) + m.DiscordController.NotifyPotentialAbuse(alertMessage) + } +} + func computeHashKeyForList(list []string, delim string) string { var buffer bytes.Buffer for i := range list { diff --git a/server/pkg/middleware/rate_limit.go b/server/pkg/middleware/rate_limit.go index bf4c403cfe..26044ba53c 100644 --- a/server/pkg/middleware/rate_limit.go +++ b/server/pkg/middleware/rate_limit.go @@ -133,7 +133,8 @@ func (r *RateLimitMiddleware) APIRateLimitForUserMiddleware(urlSanitizer func(_ // getLimiter, based on reqPath & reqMethod, return instance of limiter.Limiter which needs to // be applied for a request. It returns nil if the request is not rate limited func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limiter.Limiter { - if reqPath == "/users/public-key" { + if reqPath == "/users/public-key" || + reqPath == "/custom-domain" { return r.limit200ReqPerMin } if reqPath == "/users/ott" || diff --git a/server/pkg/repo/public/collection_link.go b/server/pkg/repo/public/collection_link.go index fafcd4cb11..3271a850b7 100644 --- a/server/pkg/repo/public/collection_link.go +++ b/server/pkg/repo/public/collection_link.go @@ -11,8 +11,6 @@ import ( "github.com/lib/pq" ) -const BaseShareURL = "https://albums.ente.io/?t=%s" - // CollectionLinkRepo defines the methods for inserting, updating and // retrieving entities related to public collections type CollectionLinkRepo struct { diff --git a/server/pkg/repo/remotestore/repository.go b/server/pkg/repo/remotestore/repository.go index 2548f49018..245d05b4eb 100644 --- a/server/pkg/repo/remotestore/repository.go +++ b/server/pkg/repo/remotestore/repository.go @@ -3,6 +3,9 @@ package remotestore import ( "context" "database/sql" + "errors" + "github.com/ente-io/museum/ente" + "github.com/lib/pq" "github.com/ente-io/stacktrace" ) @@ -21,9 +24,66 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key strin key, // $2 key_name value, // $3 key_value ) + + if err != nil { + // Check for unique violation (PostgreSQL error code 23505) + var pgErr *pq.Error + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + if pgErr.Constraint == "remote_store_custom_domain_unique_idx" { + return ente.NewConflictError("custom domain already exists for another user") + } + } + return stacktrace.Propagate(err, "failed to insert/update") + } return stacktrace.Propagate(err, "failed to insert/update") } +func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) error { + _, err := r.DB.ExecContext(ctx, `DELETE FROM remote_store + WHERE user_id = $1 AND key_name = $2`, + userID, // $1 + key, // $2 + ) + return stacktrace.Propagate(err, "failed to remove key") +} + +func (r *Repository) DomainOwner(ctx context.Context, domain string) (*int64, error) { + // Check if the domain is already taken by another user + rows := r.DB.QueryRowContext(ctx, `SELECT user_id FROM remote_store + WHERE key_name = $1 AND key_value = $2`, + ente.CustomDomain, // $1 + domain, // $2 + ) + var userID int64 + err := rows.Scan(&userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "") + } + return nil, stacktrace.Propagate(err, "failed to fetch domain owner") + } + return &userID, nil +} + +func (r *Repository) GetDomain(ctx context.Context, userID int64) (*string, error) { + // Fetch the custom domain for the user + rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store + WHERE user_id = $1 AND key_name = $2`, + userID, // $1 + ente.CustomDomain, // $2 + ) + var domain string + err := rows.Scan(&domain) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, stacktrace.Propagate(err, "failed to fetch custom domain") + } + return &domain, nil + +} + // GetValue fetches and return the value for given user_id and key func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) { rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store