From f2dc157e8a0fadd1fa2b4e7a047bcc8077a102f5 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Fri, 8 Aug 2025 13:10:57 +0530 Subject: [PATCH] Support for customDomain flag --- server/cmd/museum/main.go | 4 +- server/ente/remotestore.go | 45 ++++++++++++++----- server/pkg/api/remotestore.go | 14 ++++++ .../pkg/controller/remotestore/controller.go | 34 +++++++++++--- server/pkg/repo/remotestore/repository.go | 9 ++++ 5 files changed, 90 insertions(+), 16 deletions(-) diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 87d720530a..e3301fe6db 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -213,7 +213,6 @@ func main() { commonBillController := commonbilling.NewController(emailNotificationCtrl, storagBonusRepo, userRepo, usageRepo, billingRepo) appStoreController := controller.NewAppStoreController(defaultPlan, billingRepo, fileRepo, userRepo, commonBillController) - remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository} playStoreController := controller.NewPlayStoreController(defaultPlan, billingRepo, fileRepo, userRepo, storagBonusRepo, commonBillController) stripeController := controller.NewStripeController(plans, stripeClients, @@ -222,6 +221,8 @@ func main() { appStoreController, playStoreController, stripeController, discordController, emailNotificationCtrl, billingRepo, userRepo, usageRepo, storagBonusRepo, commonBillController) + remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository, BillingCtrl: billingController} + pushController := controller.NewPushController(pushRepo, taskLockingRepo, hostName) mailingListsController := controller.NewMailingListsController() @@ -788,6 +789,7 @@ func main() { remoteStoreHandler := &api.RemoteStoreHandler{Controller: remoteStoreController} privateAPI.POST("/remote-store/update", remoteStoreHandler.InsertOrUpdate) + privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey) privateAPI.GET("/remote-store", remoteStoreHandler.GetKey) privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags) diff --git a/server/ente/remotestore.go b/server/ente/remotestore.go index 8546fd3cf0..69f29cf543 100644 --- a/server/ente/remotestore.go +++ b/server/ente/remotestore.go @@ -23,15 +23,16 @@ type AdminUpdateKeyValueRequest struct { type FeatureFlagResponse struct { EnableStripe bool `json:"enableStripe"` // If true, the mobile client will stop using CF worker to download files - DisableCFWorker bool `json:"disableCFWorker"` - MapEnabled bool `json:"mapEnabled"` - FaceSearchEnabled bool `json:"faceSearchEnabled"` - PassKeyEnabled bool `json:"passKeyEnabled"` - RecoveryKeyVerified bool `json:"recoveryKeyVerified"` - InternalUser bool `json:"internalUser"` - BetaUser bool `json:"betaUser"` - EnableMobMultiPart bool `json:"enableMobMultiPart"` - CastUrl string `json:"castUrl"` + DisableCFWorker bool `json:"disableCFWorker"` + MapEnabled bool `json:"mapEnabled"` + FaceSearchEnabled bool `json:"faceSearchEnabled"` + PassKeyEnabled bool `json:"passKeyEnabled"` + RecoveryKeyVerified bool `json:"recoveryKeyVerified"` + InternalUser bool `json:"internalUser"` + BetaUser bool `json:"betaUser"` + EnableMobMultiPart bool `json:"enableMobMultiPart"` + CastUrl string `json:"castUrl"` + CustomDomain *string `json:"customDomain,omitempty"` } type FlagKey string @@ -43,8 +44,24 @@ const ( PassKeyEnabled FlagKey = "passKeyEnabled" IsInternalUser FlagKey = "internalUser" IsBetaUser FlagKey = "betaUser" + CustomDomain FlagKey = "customDomain" ) +var validFlagKeys = map[FlagKey]struct{}{ + RecoveryKeyVerified: {}, + MapEnabled: {}, + FaceSearchEnabled: {}, + PassKeyEnabled: {}, + IsInternalUser: {}, + IsBetaUser: {}, + CustomDomain: {}, +} + +func IsValidFlagKey(key string) bool { + _, exists := validFlagKeys[FlagKey(key)] + return exists +} + func (k FlagKey) String() string { return string(k) } @@ -52,13 +69,21 @@ func (k FlagKey) String() string { // UserEditable returns true if the key is user editable func (k FlagKey) UserEditable() bool { switch k { - case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled: + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, CustomDomain: return true default: return false } } +func (k FlagKey) NeedSubscription() bool { + return k == CustomDomain +} + +func (k FlagKey) CanRemove() bool { + return k == CustomDomain +} + func (k FlagKey) IsAdminEditable() bool { switch k { case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled: diff --git a/server/pkg/api/remotestore.go b/server/pkg/api/remotestore.go index 9f03554de8..7752c990f3 100644 --- a/server/pkg/api/remotestore.go +++ b/server/pkg/api/remotestore.go @@ -33,6 +33,20 @@ func (h *RemoteStoreHandler) InsertOrUpdate(c *gin.Context) { c.Status(http.StatusOK) } +func (h *RemoteStoreHandler) RemoveKey(c *gin.Context) { + key := c.Param("key") + if key == "" { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("key is missing"), "")) + return + } + err := h.Controller.RemoveKey(c, key) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to update key's value")) + return + } + c.Status(http.StatusOK) +} + // GetKey handler for fetching a value for particular key func (h *RemoteStoreHandler) GetKey(c *gin.Context) { var request ente.GetValueRequest diff --git a/server/pkg/controller/remotestore/controller.go b/server/pkg/controller/remotestore/controller.go index f031f65033..77d00947d4 100644 --- a/server/pkg/controller/remotestore/controller.go +++ b/server/pkg/controller/remotestore/controller.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/ente-io/museum/pkg/controller" "github.com/spf13/viper" "github.com/ente-io/museum/ente" @@ -15,20 +16,33 @@ import ( // Controller is interface for exposing business logic related to for remote store type Controller struct { - Repo *remotestore.Repository + Repo *remotestore.Repository + BillingCtrl *controller.BillingController } // InsertOrUpdate the key's value func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValueRequest) error { - if err := _validateRequest(request.Key, request.Value, false); err != nil { + userID := auth.GetUserID(ctx.Request.Header) + if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil { return err } - userID := auth.GetUserID(ctx.Request.Header) return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value) } +// RemoveKey removes the key from remote store +func (c *Controller) RemoveKey(ctx *gin.Context, key string) error { + userID := auth.GetUserID(ctx.Request.Header) + if valid := ente.IsValidFlagKey(key); !valid { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key") + } + if !ente.FlagKey(key).CanRemove() { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not removable", key)), "key not removable") + } + return c.Repo.RemoveKey(ctx, userID, key) +} + func (c *Controller) AdminInsertOrUpdate(ctx *gin.Context, request ente.AdminUpdateKeyValueRequest) error { - if err := _validateRequest(request.Key, request.Value, true); err != nil { + if err := c._validateRequest(request.UserID, request.Key, request.Value, true); err != nil { return err } return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, request.Value) @@ -80,12 +94,19 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons response.InternalUser = value == "true" case ente.IsBetaUser: response.BetaUser = value == "true" + case ente.CustomDomain: + if value != "" { + response.CustomDomain = &value + } } } return response, nil } -func _validateRequest(key, value string, byAdmin bool) error { +func (c *Controller) _validateRequest(userID int64, key, value string, byAdmin bool) error { + if ente.IsValidFlagKey(key) { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key") + } flag := ente.FlagKey(key) if !flag.UserEditable() && !byAdmin { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable") @@ -96,5 +117,8 @@ func _validateRequest(key, value string, byAdmin bool) error { if flag.IsBoolType() && value != "true" && value != "false" { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed") } + if flag.NeedSubscription() { + return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true) + } return nil } diff --git a/server/pkg/repo/remotestore/repository.go b/server/pkg/repo/remotestore/repository.go index 2548f49018..d9dd50ac4d 100644 --- a/server/pkg/repo/remotestore/repository.go +++ b/server/pkg/repo/remotestore/repository.go @@ -24,6 +24,15 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key strin return stacktrace.Propagate(err, "failed to insert/update") } +func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) error { + _, err := r.DB.ExecContext(ctx, `DELETE FROM remote_store + WHERE user_id = $1 AND key_name = $2`, + userID, // $1 + key, // $2 + ) + return stacktrace.Propagate(err, "failed to remove key") +} + // GetValue fetches and return the value for given user_id and key func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) { rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store