[server] Support for configuring custom-domains (#6827)

## Description

## Tests
- Verified server db migrations is on 103
- Verified that duplicate custom domain results in error
- basic sanity testing for custom domain validation.
This commit is contained in:
Neeraj
2025-08-12 17:25:47 +05:30
committed by GitHub
13 changed files with 379 additions and 38 deletions

View File

@@ -98,6 +98,7 @@ func main() {
}
viper.SetDefault("apps.public-albums", "https://albums.ente.io")
viper.SetDefault("apps.custom-domain.cname", "my.ente.io")
viper.SetDefault("apps.public-locker", "https://locker.ente.io")
viper.SetDefault("apps.accounts", "https://accounts.ente.io")
viper.SetDefault("apps.cast", "https://cast.ente.io")
@@ -213,7 +214,6 @@ func main() {
commonBillController := commonbilling.NewController(emailNotificationCtrl, storagBonusRepo, userRepo, usageRepo, billingRepo)
appStoreController := controller.NewAppStoreController(defaultPlan,
billingRepo, fileRepo, userRepo, commonBillController)
remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository}
playStoreController := controller.NewPlayStoreController(defaultPlan,
billingRepo, fileRepo, userRepo, storagBonusRepo, commonBillController)
stripeController := controller.NewStripeController(plans, stripeClients,
@@ -222,6 +222,8 @@ func main() {
appStoreController, playStoreController, stripeController,
discordController, emailNotificationCtrl,
billingRepo, userRepo, usageRepo, storagBonusRepo, commonBillController)
remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository, BillingCtrl: billingController}
pushController := controller.NewPushController(pushRepo, taskLockingRepo, hostName)
mailingListsController := controller.NewMailingListsController()
@@ -376,6 +378,7 @@ func main() {
Cache: accessTokenCache,
BillingCtrl: billingController,
DiscordController: discordController,
RemoteStoreRepo: remoteStoreRepository,
}
fileLinkMiddleware := &middleware.FileLinkMiddleware{
FileLinkRepo: fileLinkRepo,
@@ -788,8 +791,10 @@ func main() {
remoteStoreHandler := &api.RemoteStoreHandler{Controller: remoteStoreController}
privateAPI.POST("/remote-store/update", remoteStoreHandler.InsertOrUpdate)
privateAPI.DELETE("/remote-store/:key", remoteStoreHandler.RemoveKey)
privateAPI.GET("/remote-store", remoteStoreHandler.GetKey)
privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags)
publicAPI.GET("/custom-domain", remoteStoreHandler.CheckDomain)
pushHandler := &api.PushHandler{PushController: pushController}
privateAPI.POST("/push/token", pushHandler.AddToken)

View File

@@ -95,6 +95,9 @@ apps:
accounts:
# Default is https://family.ente.io
family:
custom-domain:
# Default is my.ente.io
cname:
# Database connection parameters
db:

View File

@@ -1,5 +1,12 @@
package ente
import (
"fmt"
"github.com/ente-io/stacktrace"
"regexp"
"strings"
)
type GetValueRequest struct {
Key string `form:"key" binding:"required"`
DefaultValue *string `form:"defaultValue"`
@@ -10,28 +17,30 @@ type GetValueResponse struct {
}
type UpdateKeyValueRequest struct {
Key string `json:"key" binding:"required"`
Value string `json:"value" binding:"required"`
Key string `json:"key" binding:"required"`
Value *string `json:"value" binding:"required"`
}
type AdminUpdateKeyValueRequest struct {
UserID int64 `json:"userID" binding:"required"`
Key string `json:"key" binding:"required"`
Value string `json:"value" binding:"required"`
UserID int64 `json:"userID" binding:"required"`
Key string `json:"key" binding:"required"`
Value *string `json:"value" binding:"required"`
}
type FeatureFlagResponse struct {
EnableStripe bool `json:"enableStripe"`
// If true, the mobile client will stop using CF worker to download files
DisableCFWorker bool `json:"disableCFWorker"`
MapEnabled bool `json:"mapEnabled"`
FaceSearchEnabled bool `json:"faceSearchEnabled"`
PassKeyEnabled bool `json:"passKeyEnabled"`
RecoveryKeyVerified bool `json:"recoveryKeyVerified"`
InternalUser bool `json:"internalUser"`
BetaUser bool `json:"betaUser"`
EnableMobMultiPart bool `json:"enableMobMultiPart"`
CastUrl string `json:"castUrl"`
DisableCFWorker bool `json:"disableCFWorker"`
MapEnabled bool `json:"mapEnabled"`
FaceSearchEnabled bool `json:"faceSearchEnabled"`
PassKeyEnabled bool `json:"passKeyEnabled"`
RecoveryKeyVerified bool `json:"recoveryKeyVerified"`
InternalUser bool `json:"internalUser"`
BetaUser bool `json:"betaUser"`
EnableMobMultiPart bool `json:"enableMobMultiPart"`
CastUrl string `json:"castUrl"`
CustomDomain *string `json:"customDomain,omitempty"`
CustomDomainCNAME string `json:"customDomainCNAME,omitempty"`
}
type FlagKey string
@@ -43,8 +52,24 @@ const (
PassKeyEnabled FlagKey = "passKeyEnabled"
IsInternalUser FlagKey = "internalUser"
IsBetaUser FlagKey = "betaUser"
CustomDomain FlagKey = "customDomain"
)
var validFlagKeys = map[FlagKey]struct{}{
RecoveryKeyVerified: {},
MapEnabled: {},
FaceSearchEnabled: {},
PassKeyEnabled: {},
IsInternalUser: {},
IsBetaUser: {},
CustomDomain: {},
}
func IsValidFlagKey(key string) bool {
_, exists := validFlagKeys[FlagKey(key)]
return exists
}
func (k FlagKey) String() string {
return string(k)
}
@@ -52,13 +77,21 @@ func (k FlagKey) String() string {
// UserEditable returns true if the key is user editable
func (k FlagKey) UserEditable() bool {
switch k {
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled:
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, CustomDomain:
return true
default:
return false
}
}
func (k FlagKey) NeedSubscription() bool {
return k == CustomDomain
}
func (k FlagKey) CanRemove() bool {
return k == CustomDomain
}
func (k FlagKey) IsAdminEditable() bool {
switch k {
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled:
@@ -74,7 +107,40 @@ func (k FlagKey) IsBoolType() bool {
switch k {
case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser:
return true
default:
case CustomDomain:
return false
default:
return false // Explicitly handle unexpected cases
}
}
func (k FlagKey) IsValidValue(value string) error {
if k.IsBoolType() && value != "true" && value != "false" {
return stacktrace.Propagate(NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
}
if k == CustomDomain && value != "" {
if err := isValidDomainWithoutScheme(value); err != nil {
return stacktrace.Propagate(err, "invalid custom domain")
}
}
return nil
}
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
func isValidDomainWithoutScheme(input string) error {
trimmed := strings.TrimSpace(input)
if trimmed != input {
return NewBadRequestWithMessage("domain contains leading or trailing spaces")
}
if trimmed == "" {
return NewBadRequestWithMessage("domain is empty")
}
if strings.Contains(trimmed, "://") {
return NewBadRequestWithMessage("domain should not contain scheme (e.g., http:// or https://)")
}
if !domainRegex.MatchString(trimmed) {
return NewBadRequestWithMessage(fmt.Sprintf("invalid domain format: %s", trimmed))
}
return nil
}

View File

@@ -0,0 +1,92 @@
package ente
import "testing"
func TestIsValidDomainWithoutScheme(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
// ✅ Valid cases
{"simple domain", "google.com", false},
{"multi-level domain", "sub.example.co.in", false},
{"numeric in label", "a1b2c3.com", false},
{"long but valid label", "my-very-long-subdomain-name.example.com", false},
// ❌ Leading/trailing spaces
{"leading space", " google.com", true},
{"trailing space", "google.com ", true},
{"both spaces", " google.com ", true},
// ❌ Empty or whitespace
{"empty string", "", true},
{"only spaces", " ", true},
// ❌ Scheme included
{"http scheme", "http://google.com", true},
{"https scheme", "https://example.com", true},
{"ftp scheme", "ftp://example.com", true},
// ❌ Invalid characters
{"underscore in label", "my_domain.com", true},
{"invalid symbol", "exa$mple.com", true},
{"space inside", "exa mple.com", true},
// ❌ Wrong format
{"missing dot", "localhost", true},
{"single label TLD", "com", true},
{"ends with dot", "example.com.", true},
{"ends with dash", "example-.com", true},
{"starts with dash", "-example.com", true},
// ❌ Consecutive dots
{"double dots", "example..com", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := isValidDomainWithoutScheme(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("isValidDomainWithoutScheme(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}
func TestFlagKey_IsValidValue(t *testing.T) {
tests := []struct {
name string
key FlagKey
value string
wantErr bool
}{
// ✅ Valid boolean flag values
{"valid true for bool key", MapEnabled, "true", false},
{"valid false for bool key", FaceSearchEnabled, "false", false},
// ❌ Invalid boolean flag values
{"invalid value for bool key", PassKeyEnabled, "yes", true},
{"empty value for bool key", IsInternalUser, "", true},
// ✅ Valid custom domain values
{"valid custom domain", CustomDomain, "example.com", false},
{"valid subdomain", CustomDomain, "sub.example.com", false},
// ❌ Invalid custom domain values
{"empty custom domain", CustomDomain, "", false}, // Allowed as empty
{"custom domain with scheme", CustomDomain, "http://example.com", true},
{"custom domain with invalid format", CustomDomain, "exa$mple.com", true},
{"custom domain with leading space", CustomDomain, " example.com", true},
{"custom domain with trailing space", CustomDomain, "example.com ", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.key.IsValidValue(tt.value)
if (err != nil) != tt.wantErr {
t.Errorf("FlagKey(%q).IsValidValue(%q) error = %v, wantErr %v", tt.key, tt.value, err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1 @@
DROP INDEX IF EXISTS remote_store_custom_domain_unique_idx;

View File

@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX IF NOT EXISTS remote_store_custom_domain_unique_idx
ON remote_store (key_value)
WHERE key_name = 'customDomain';

View File

@@ -3,13 +3,14 @@ package api
import (
"errors"
"fmt"
"github.com/ente-io/museum/pkg/controller/emergency"
"github.com/ente-io/museum/pkg/controller/remotestore"
"github.com/ente-io/museum/pkg/repo/authenticator"
"net/http"
"strconv"
"strings"
"github.com/ente-io/museum/pkg/controller/emergency"
"github.com/ente-io/museum/pkg/controller/remotestore"
"github.com/ente-io/museum/pkg/repo/authenticator"
"github.com/ente-io/museum/pkg/controller/family"
bonusEntity "github.com/ente-io/museum/ente/storagebonus"
@@ -377,7 +378,7 @@ func (h *AdminHandler) UpdateFeatureFlag(c *gin.Context) {
return
}
go h.DiscordController.NotifyAdminAction(
fmt.Sprintf("Admin (%d) updating flag:%s to val:%s for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID))
fmt.Sprintf("Admin (%d) updating flag:%s to val:%v for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID))
logger := logrus.WithFields(logrus.Fields{
"user_id": request.UserID,

View File

@@ -33,6 +33,20 @@ func (h *RemoteStoreHandler) InsertOrUpdate(c *gin.Context) {
c.Status(http.StatusOK)
}
func (h *RemoteStoreHandler) RemoveKey(c *gin.Context) {
key := c.Param("key")
if key == "" {
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("key is missing"), ""))
return
}
err := h.Controller.RemoveKey(c, key)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, "failed to update key's value"))
return
}
c.Status(http.StatusOK)
}
// GetKey handler for fetching a value for particular key
func (h *RemoteStoreHandler) GetKey(c *gin.Context) {
var request ente.GetValueRequest
@@ -59,3 +73,18 @@ func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) {
}
c.JSON(http.StatusOK, resp)
}
// CheckDomain returns 200 ok if the custom domain is claimed by any ente user
func (h *RemoteStoreHandler) CheckDomain(c *gin.Context) {
domain := c.Query("domain")
if domain == "" {
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("domain is missing"), ""))
return
}
_, err := h.Controller.DomainOwner(c, domain)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags"))
return
}
c.JSON(http.StatusOK, gin.H{})
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/controller"
"github.com/spf13/viper"
"github.com/ente-io/museum/ente"
@@ -15,23 +16,39 @@ import (
// Controller is interface for exposing business logic related to for remote store
type Controller struct {
Repo *remotestore.Repository
Repo *remotestore.Repository
BillingCtrl *controller.BillingController
}
// InsertOrUpdate the key's value
func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValueRequest) error {
if err := _validateRequest(request.Key, request.Value, false); err != nil {
userID := auth.GetUserID(ctx.Request.Header)
if err := c._validateRequest(userID, request.Key, request.Value, false); err != nil {
return err
}
if *request.Value == "" && ente.FlagKey(request.Key).CanRemove() {
return c.Repo.RemoveKey(ctx, userID, request.Key)
}
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, *request.Value)
}
// RemoveKey removes the key from remote store
func (c *Controller) RemoveKey(ctx *gin.Context, key string) error {
userID := auth.GetUserID(ctx.Request.Header)
return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value)
if valid := ente.IsValidFlagKey(key); !valid {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
}
if !ente.FlagKey(key).CanRemove() {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not removable", key)), "key not removable")
}
return c.Repo.RemoveKey(ctx, userID, key)
}
func (c *Controller) AdminInsertOrUpdate(ctx *gin.Context, request ente.AdminUpdateKeyValueRequest) error {
if err := _validateRequest(request.Key, request.Value, true); err != nil {
if err := c._validateRequest(request.UserID, request.Key, request.Value, true); err != nil {
return err
}
return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, request.Value)
return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, *request.Value)
}
func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetValueResponse, error) {
@@ -61,12 +78,10 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
// except internal user.rt
EnableMobMultiPart: true,
CastUrl: viper.GetString("apps.cast"),
CustomDomainCNAME: viper.GetString("apps.custom-domain.cname"),
}
for key, value := range values {
flag := ente.FlagKey(key)
if !flag.IsBoolType() {
continue
}
switch flag {
case ente.RecoveryKeyVerified:
response.RecoveryKeyVerified = value == "true"
@@ -80,21 +95,40 @@ func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagRespons
response.InternalUser = value == "true"
case ente.IsBetaUser:
response.BetaUser = value == "true"
case ente.CustomDomain:
if value != "" {
response.CustomDomain = &value
}
}
}
return response, nil
}
func _validateRequest(key, value string, byAdmin bool) error {
func (c *Controller) DomainOwner(ctx *gin.Context, domain string) (*int64, error) {
return c.Repo.DomainOwner(ctx, domain)
}
func (c *Controller) _validateRequest(userID int64, key string, valuePtr *string, byAdmin bool) error {
if !ente.IsValidFlagKey(key) {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", key)), "invalid flag key")
}
if valuePtr == nil {
return stacktrace.Propagate(ente.NewBadRequestWithMessage("value is missing"), "value is nil")
}
value := *valuePtr
flag := ente.FlagKey(key)
if err := flag.IsValidValue(value); err != nil {
return stacktrace.Propagate(err, "")
}
if !flag.UserEditable() && !byAdmin {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable")
}
if byAdmin && !flag.IsAdminEditable() {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable")
}
if flag.IsBoolType() && value != "true" && value != "false" {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed")
if flag.NeedSubscription() {
return c.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, true)
}
return nil
}

View File

@@ -6,6 +6,12 @@ import (
"crypto/sha256"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/ente-io/museum/pkg/repo/remotestore"
"github.com/gin-contrib/requestid"
"github.com/spf13/viper"
public2 "github.com/ente-io/museum/pkg/controller/public"
"github.com/ente-io/museum/pkg/repo/public"
@@ -35,6 +41,7 @@ type CollectionLinkMiddleware struct {
Cache *cache.Cache
BillingCtrl *controller.BillingController
DiscordController *discord.DiscordController
RemoteStoreRepo *remotestore.Repository
}
// Authenticate returns a middle ware that extracts the `X-Auth-Access-Token`
@@ -65,7 +72,7 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context
return
}
// validate if user still has active paid subscription
if err = m.validateOwnersSubscription(publicCollectionSummary.CollectionID); err != nil {
if err = m.validateOwnersSubscription(c, publicCollectionSummary.CollectionID); err != nil {
logrus.WithError(err).Warn("failed to verify active paid subscription")
c.AbortWithStatusJSON(http.StatusGone, gin.H{"error": "no active subscription"})
return
@@ -115,12 +122,17 @@ func (m *CollectionLinkMiddleware) Authenticate(urlSanitizer func(_ *gin.Context
c.Next()
}
}
func (m *CollectionLinkMiddleware) validateOwnersSubscription(cID int64) error {
func (m *CollectionLinkMiddleware) validateOwnersSubscription(c *gin.Context, cID int64) error {
userID, err := m.CollectionRepo.GetOwnerID(cID)
if err != nil {
return stacktrace.Propagate(err, "")
}
return m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
err = m.BillingCtrl.HasActiveSelfOrFamilySubscription(userID, false)
if err != nil {
return stacktrace.Propagate(err, "failed to validate owners subscription")
}
m.validateOrigin(c, userID)
return nil
}
func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context,
@@ -177,6 +189,42 @@ func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath stri
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash)
}
func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) {
origin := c.Request.Header.Get("Origin")
if origin == "" || origin == viper.GetString("apps.public-albums") {
return
}
reqId := requestid.Get(c)
logger := logrus.WithFields(logrus.Fields{
"ownerID": ownerID,
"req_id": reqId,
"origin": origin,
})
alertMessage := fmt.Sprintf("custom domain check failed %s", reqId)
domain, err := m.RemoteStoreRepo.GetDomain(c, ownerID)
if err != nil {
logger.WithError(err).Error("failed to fetch custom domain for owner")
m.DiscordController.NotifyPotentialAbuse(alertMessage)
return
}
if domain == nil || *domain == "" {
logger.Warn("custom domain is nil or empty")
m.DiscordController.NotifyPotentialAbuse(alertMessage)
return
}
parse, err := url.Parse(origin)
if err != nil {
logger.WithError(err).Error("failed to parse origin URL")
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - failed to parse origin URL")
return
}
if !strings.Contains(strings.ToLower(parse.Host), strings.ToLower(*domain)) {
logger.Warnf("custom domain check failed for owner %d, origin %s, domain %s", ownerID, origin, *domain)
m.DiscordController.NotifyPotentialAbuse(alertMessage)
}
}
func computeHashKeyForList(list []string, delim string) string {
var buffer bytes.Buffer
for i := range list {

View File

@@ -133,7 +133,8 @@ func (r *RateLimitMiddleware) APIRateLimitForUserMiddleware(urlSanitizer func(_
// getLimiter, based on reqPath & reqMethod, return instance of limiter.Limiter which needs to
// be applied for a request. It returns nil if the request is not rate limited
func (r *RateLimitMiddleware) getLimiter(reqPath string, reqMethod string) *limiter.Limiter {
if reqPath == "/users/public-key" {
if reqPath == "/users/public-key" ||
reqPath == "/custom-domain" {
return r.limit200ReqPerMin
}
if reqPath == "/users/ott" ||

View File

@@ -11,8 +11,6 @@ import (
"github.com/lib/pq"
)
const BaseShareURL = "https://albums.ente.io/?t=%s"
// CollectionLinkRepo defines the methods for inserting, updating and
// retrieving entities related to public collections
type CollectionLinkRepo struct {

View File

@@ -3,6 +3,9 @@ package remotestore
import (
"context"
"database/sql"
"errors"
"github.com/ente-io/museum/ente"
"github.com/lib/pq"
"github.com/ente-io/stacktrace"
)
@@ -21,9 +24,66 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key strin
key, // $2 key_name
value, // $3 key_value
)
if err != nil {
// Check for unique violation (PostgreSQL error code 23505)
var pgErr *pq.Error
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
if pgErr.Constraint == "remote_store_custom_domain_unique_idx" {
return ente.NewConflictError("custom domain already exists for another user")
}
}
return stacktrace.Propagate(err, "failed to insert/update")
}
return stacktrace.Propagate(err, "failed to insert/update")
}
func (r *Repository) RemoveKey(ctx context.Context, userID int64, key string) error {
_, err := r.DB.ExecContext(ctx, `DELETE FROM remote_store
WHERE user_id = $1 AND key_name = $2`,
userID, // $1
key, // $2
)
return stacktrace.Propagate(err, "failed to remove key")
}
func (r *Repository) DomainOwner(ctx context.Context, domain string) (*int64, error) {
// Check if the domain is already taken by another user
rows := r.DB.QueryRowContext(ctx, `SELECT user_id FROM remote_store
WHERE key_name = $1 AND key_value = $2`,
ente.CustomDomain, // $1
domain, // $2
)
var userID int64
err := rows.Scan(&userID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "")
}
return nil, stacktrace.Propagate(err, "failed to fetch domain owner")
}
return &userID, nil
}
func (r *Repository) GetDomain(ctx context.Context, userID int64) (*string, error) {
// Fetch the custom domain for the user
rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store
WHERE user_id = $1 AND key_name = $2`,
userID, // $1
ente.CustomDomain, // $2
)
var domain string
err := rows.Scan(&domain)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, stacktrace.Propagate(err, "failed to fetch custom domain")
}
return &domain, nil
}
// GetValue fetches and return the value for given user_id and key
func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (string, error) {
rows := r.DB.QueryRowContext(ctx, `SELECT key_value FROM remote_store