[server] Fail request on customDomain mismatch (#6893)

## Description

## Tests
This commit is contained in:
Neeraj
2025-08-19 14:23:00 +05:30
committed by GitHub

View File

@@ -131,8 +131,7 @@ func (m *CollectionLinkMiddleware) validateOwnersSubscription(c *gin.Context, cI
if err != nil {
return stacktrace.Propagate(err, "failed to validate owners subscription")
}
m.validateOrigin(c, userID)
return nil
return m.validateOrigin(c, userID)
}
func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context,
@@ -189,11 +188,11 @@ func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath stri
return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash)
}
func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) {
func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) error {
origin := c.Request.Header.Get("Origin")
if origin == "" || origin == viper.GetString("apps.public-albums") {
return
return nil
}
reqId := requestid.Get(c)
logger := logrus.WithFields(logrus.Fields{
@@ -201,28 +200,30 @@ func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64)
"req_id": reqId,
"origin": origin,
})
alertMessage := fmt.Sprintf("custom domain check failed %s", reqId)
alertMessage := fmt.Sprintf("custom domain %s", reqId)
domain, err := m.RemoteStoreRepo.GetDomain(c, ownerID)
if err != nil {
logger.WithError(err).Error("failed to fetch custom domain for owner")
m.DiscordController.NotifyPotentialAbuse(alertMessage)
return
logger.WithError(err).Error("domainFetchFailed")
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainFetchFailed")
return nil
}
if domain == nil || *domain == "" {
logger.Warn("custom domain is nil or empty")
m.DiscordController.NotifyPotentialAbuse(alertMessage)
return
logger.Warn("domainNotConfigured")
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainNotConfigured")
return ente.NewPermissionDeniedError("no custom domain configured")
}
parse, err := url.Parse(origin)
if err != nil {
logger.WithError(err).Error("failed to parse origin URL")
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - failed to parse origin URL")
return
logger.WithError(err).Error("originParseFailedL")
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - originParseFailed")
return nil
}
if !strings.Contains(strings.ToLower(parse.Host), strings.ToLower(*domain)) {
logger.Warnf("custom domain check failed for owner %d, origin %s, domain %s", ownerID, origin, *domain)
m.DiscordController.NotifyPotentialAbuse(alertMessage)
logger.Warnf("domainMismatch for owner %d, origin %s, domain %s host %s", ownerID, origin, *domain, parse.Host)
m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainMismatch")
return ente.NewPermissionDeniedError("unknown custom domain")
}
return nil
}
func computeHashKeyForList(list []string, delim string) string {