Endpoint to check domain claim

This commit is contained in:
Neeraj Gupta
2025-08-08 15:40:05 +05:30
parent 23103c3bcc
commit 1c37332f37
5 changed files with 40 additions and 1 deletions

View File

@@ -793,6 +793,7 @@ func main() {
privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey)
privateAPI.GET("/remote-store", remoteStoreHandler.GetKey)
privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags)
privateAPI.GET("/custom-domain", remoteStoreHandler.CheckDomain)
pushHandler := &api.PushHandler{PushController: pushController}
privateAPI.POST("/push/token", pushHandler.AddToken)

View File

@@ -73,3 +73,18 @@ func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) {
}
c.JSON(http.StatusOK, resp)
}
// CheckDomain returns 200 ok if the custom domain is claimed by any ente user
func (h *RemoteStoreHandler) CheckDomain(c *gin.Context) {
domain := c.Query("domain")
if domain == "" {
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("domain is missing"), ""))
return
}
_, err := h.Controller.DomainOwner(c, domain)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags"))
return
}
c.JSON(http.StatusOK, gin.H{})
}

View File

@@ -104,6 +104,10 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
return response, nil
}
func (c *Controller) DomainOwner(ctx *gin.Context, domain string) (*int64, error) {
return c.Repo.DomainOwner(ctx, domain)
}
func (c *Controller) _validateRequest(userID int64, key, value string, byAdmin bool) error {
if ente.IsValidFlagKey(key) {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")

View File

@@ -133,7 +133,8 @@ func (r *RateLimitMiddleware) APIRateLimitForUserMiddleware(urlSanitizer func(_
// getLimiter, based on reqPath & reqMethod, return instance of limiter.Limiter which needs to
// be applied for a request. It returns nil if the request is not rate limited
func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limiter.Limiter {
if reqPath == "/users/public-key" {
if reqPath == "/users/public-key" ||
reqPath == "/custom-domain" {
return r.limit200ReqPerMin
}
if reqPath == "/users/ott" ||

View File

@@ -47,6 +47,24 @@ func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) er
return stacktrace.Propagate(err, "failed to remove key")
}
func (r *Repository) DomainOwner(ctx context.Context, domain string) (*int64, error) {
// Check if the domain is already taken by another user
rows := r.DB.QueryRowContext(ctx, `SELECT user_id FROM remote_store
WHERE key_name = $1 AND key_value = $2`,
ente.CustomDomain, // $1
domain, // $2
)
var userID int64
err := rows.Scan(&userID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "")
}
return nil, stacktrace.Propagate(err, "failed to fetch domain owner")
}
return &userID, nil
}
// GetValue fetches and return the value for given user_id and key
func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) {
rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store