From 298e3695c75e1df823967ff22f79f1140be216d1 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 19 Aug 2025 12:34:31 +0530 Subject: [PATCH] [server] Fail request on customDomain mismatch --- server/pkg/middleware/collection_link.go | 33 ++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/server/pkg/middleware/collection_link.go b/server/pkg/middleware/collection_link.go index a0e59df2ca..6efba9fa44 100644 --- a/server/pkg/middleware/collection_link.go +++ b/server/pkg/middleware/collection_link.go @@ -131,8 +131,7 @@ func (m *CollectionLinkMiddleware) validateOwnersSubscription(c *gin.Context, cI if err != nil { return stacktrace.Propagate(err, "failed to validate owners subscription") } - m.validateOrigin(c, userID) - return nil + return m.validateOrigin(c, userID) } func (m *CollectionLinkMiddleware) isDeviceLimitReached(ctx context.Context, @@ -189,11 +188,11 @@ func (m *CollectionLinkMiddleware) validatePassword(c *gin.Context, reqPath stri return m.PublicCollectionCtrl.ValidateJWTToken(c, accessTokenJWT, *collectionSummary.PassHash) } -func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) { +func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) error { origin := c.Request.Header.Get("Origin") if origin == "" || origin == viper.GetString("apps.public-albums") { - return + return nil } reqId := requestid.Get(c) logger := logrus.WithFields(logrus.Fields{ @@ -201,28 +200,30 @@ func (m *CollectionLinkMiddleware) validateOrigin(c *gin.Context, ownerID int64) "req_id": reqId, "origin": origin, }) - alertMessage := fmt.Sprintf("custom domain check failed %s", reqId) + alertMessage := fmt.Sprintf("custom domain %s", reqId) domain, err := m.RemoteStoreRepo.GetDomain(c, ownerID) if err != nil { - logger.WithError(err).Error("failed to fetch custom domain for owner") - m.DiscordController.NotifyPotentialAbuse(alertMessage) - return + logger.WithError(err).Error("domainFetchFailed") + m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainFetchFailed") + return nil } if domain == nil || *domain == "" { - logger.Warn("custom domain is nil or empty") - m.DiscordController.NotifyPotentialAbuse(alertMessage) - return + logger.Warn("domainNotConfigured") + m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainNotConfigured") + return ente.NewPermissionDeniedError("no custom domain configured") } parse, err := url.Parse(origin) if err != nil { - logger.WithError(err).Error("failed to parse origin URL") - m.DiscordController.NotifyPotentialAbuse(alertMessage + " - failed to parse origin URL") - return + logger.WithError(err).Error("originParseFailedL") + m.DiscordController.NotifyPotentialAbuse(alertMessage + " - originParseFailed") + return nil } if !strings.Contains(strings.ToLower(parse.Host), strings.ToLower(*domain)) { - logger.Warnf("custom domain check failed for owner %d, origin %s, domain %s", ownerID, origin, *domain) - m.DiscordController.NotifyPotentialAbuse(alertMessage) + logger.Warnf("domainMismatch for owner %d, origin %s, domain %s host %s", ownerID, origin, *domain, parse.Host) + m.DiscordController.NotifyPotentialAbuse(alertMessage + " - domainMismatch") + return ente.NewPermissionDeniedError("unknown custom domain") } + return nil } func computeHashKeyForList(list []string, delim string) string {