diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index cc10a67e8b..86f5537c7b 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -793,6 +793,7 @@ func main() { privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey) privateAPI.GET("/remote-store", remoteStoreHandler.GetKey) privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags) + privateAPI.GET("/custom-domain", remoteStoreHandler.CheckDomain) pushHandler := &api.PushHandler{PushController: pushController} privateAPI.POST("/push/token", pushHandler.AddToken) diff --git a/server/pkg/api/remotestore.go b/server/pkg/api/remotestore.go index 7752c990f3..e83883c2dc 100644 --- a/server/pkg/api/remotestore.go +++ b/server/pkg/api/remotestore.go @@ -73,3 +73,18 @@ func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) { } c.JSON(http.StatusOK, resp) } + +// CheckDomain returns 200 ok if the custom domain is claimed by any ente user +func (h *RemoteStoreHandler) CheckDomain(c *gin.Context) { + domain := c.Query("domain") + if domain == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("domain is missing"), "")) + return + } + _, err := h.Controller.DomainOwner(c, domain) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/server/pkg/controller/remotestore/controller.go b/server/pkg/controller/remotestore/controller.go index e6769ce443..d6932a5aaf 100644 --- a/server/pkg/controller/remotestore/controller.go +++ b/server/pkg/controller/remotestore/controller.go @@ -104,6 +104,10 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons return response, nil } +func (c *Controller) DomainOwner(ctx *gin.Context, domain string) (*int64, error) { + return c.Repo.DomainOwner(ctx, domain) +} + func (c *Controller) _validateRequest(userID int64, key, value string, byAdmin bool) error { if ente.IsValidFlagKey(key) { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key") diff --git a/server/pkg/middleware/rate_limit.go b/server/pkg/middleware/rate_limit.go index bf4c403cfe..26044ba53c 100644 --- a/server/pkg/middleware/rate_limit.go +++ b/server/pkg/middleware/rate_limit.go @@ -133,7 +133,8 @@ func (r *RateLimitMiddleware) APIRateLimitForUserMiddleware(urlSanitizer func(_ // getLimiter, based on reqPath & reqMethod, return instance of limiter.Limiter which needs to // be applied for a request. It returns nil if the request is not rate limited func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limiter.Limiter { - if reqPath == "/users/public-key" { + if reqPath == "/users/public-key" || + reqPath == "/custom-domain" { return r.limit200ReqPerMin } if reqPath == "/users/ott" || diff --git a/server/pkg/repo/remotestore/repository.go b/server/pkg/repo/remotestore/repository.go index 681942365c..f2bd8d7789 100644 --- a/server/pkg/repo/remotestore/repository.go +++ b/server/pkg/repo/remotestore/repository.go @@ -47,6 +47,24 @@ func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) er return stacktrace.Propagate(err, "failed to remove key") } +func (r *Repository) DomainOwner(ctx context.Context, domain string) (*int64, error) { + // Check if the domain is already taken by another user + rows := r.DB.QueryRowContext(ctx, `SELECT user_id FROM remote_store + WHERE key_name = $1 AND key_value = $2`, + ente.CustomDomain, // $1 + domain, // $2 + ) + var userID int64 + err := rows.Scan(&userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "") + } + return nil, stacktrace.Propagate(err, "failed to fetch domain owner") + } + return &userID, nil +} + // GetValue fetches and return the value for given user_id and key func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) { rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store