From 2fe703df925d82e49632feaa60959ed52ebf823f Mon Sep 17 00:00:00 2001 From: Neeraj Gupta Date: Wed, 3 Apr 2024 12:38:34 +0530 Subject: [PATCH] [server] Increase embedding fetch limit (#1300) ## Description Also use different semaphore than existing diff API ## Tests --- server/pkg/controller/embedding/controller.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 7f2f5dd805..d6e78209fa 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -275,7 +275,9 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, er return len(embeddingObj), nil } -var globalFetchSemaphore = make(chan struct{}, 300) +var globalDiffFetchSemaphore = make(chan struct{}, 300) + +var globalFileFetchSemaphore = make(chan struct{}, 400) func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) { var wg sync.WaitGroup @@ -285,10 +287,10 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em for i, objectKey := range objectKeys { wg.Add(1) - globalFetchSemaphore <- struct{}{} // Acquire from global semaphore + globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore go func(i int, objectKey string) { defer wg.Done() - defer func() { <-globalFetchSemaphore }() // Release back to global semaphore + defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore obj, err := c.getEmbeddingObject(objectKey, downloader) if err != nil { @@ -322,10 +324,10 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows for i, dbEmbeddingRow := range dbEmbeddingRows { wg.Add(1) - globalFetchSemaphore <- struct{}{} // Acquire from global semaphore + globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore go func(i int, dbEmbeddingRow ente.Embedding) { defer wg.Done() - defer func() { <-globalFetchSemaphore }() // Release back to global semaphore + defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) obj, err := c.getEmbeddingObject(objectKey, downloader) if err != nil { @@ -373,8 +375,8 @@ func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID if len(req.FileIDs) == 0 { return ente.NewBadRequestWithMessage("fileIDs are required") } - if len(req.FileIDs) > 100 { - return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100") + if len(req.FileIDs) > 200 { + return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200") } if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{ ActorUserId: userID,