diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 25edbdaa65..2df2d1b560 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -36,33 +36,35 @@ const ( ) type Controller struct { - Repo *embedding.Repository - AccessCtrl access.Controller - ObjectCleanupController *controller.ObjectCleanupController - S3Config *s3config.S3Config - QueueRepo *repo.QueueRepository - TaskLockingRepo *repo.TaskLockRepository - FileRepo *repo.FileRepository - CollectionRepo *repo.CollectionRepository - HostName string - cleanupCronRunning bool - embeddingS3Client *s3.S3 - embeddingBucket *string + Repo *embedding.Repository + AccessCtrl access.Controller + ObjectCleanupController *controller.ObjectCleanupController + S3Config *s3config.S3Config + QueueRepo *repo.QueueRepository + TaskLockingRepo *repo.TaskLockRepository + FileRepo *repo.FileRepository + CollectionRepo *repo.CollectionRepository + HostName string + cleanupCronRunning bool + embeddingS3Client *s3.S3 + embeddingBucket *string + areEmbeddingAndHotBucketSame bool } func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller { return &Controller{ - Repo: repo, - AccessCtrl: accessCtrl, - ObjectCleanupController: objectCleanupController, - S3Config: s3Config, - QueueRepo: queueRepo, - TaskLockingRepo: taskLockingRepo, - FileRepo: fileRepo, - CollectionRepo: collectionRepo, - HostName: hostName, - embeddingS3Client: s3Config.GetEmbeddingsS3Client(), - embeddingBucket: s3Config.GetEmbeddingsBucket(), + Repo: repo, + AccessCtrl: accessCtrl, + ObjectCleanupController: objectCleanupController, + S3Config: s3Config, + QueueRepo: queueRepo, + TaskLockingRepo: taskLockingRepo, + FileRepo: fileRepo, + CollectionRepo: collectionRepo, + HostName: hostName, + embeddingS3Client: s3Config.GetEmbeddingsS3Client(), + embeddingBucket: s3Config.GetEmbeddingsBucket(), + areEmbeddingAndHotBucketSame: s3Config.GetEmbeddingsBucket() == s3Config.GetHotBucket(), } } @@ -269,7 +271,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) { return } // if Embeddings DC is different from hot DC, delete from hot DC as well - if c.S3Config.GetEmbeddingsDataCenter() != c.S3Config.GetHotDataCenter() { + if !c.areEmbeddingAndHotBucketSame { err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter()) if err != nil { ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC") @@ -425,10 +427,21 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err()) } else { // check if the error is due to object not found - if s3Err, ok := errors.Unwrap(err).(awserr.Error); ok { + if s3Err, ok := err.(awserr.RequestFailure); ok { if s3Err.Code() == s3.ErrCodeNoSuchKey { - ctxLogger.Warn("Object not found: ", s3Err) - return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") + if c.areEmbeddingAndHotBucketSame { + ctxLogger.Error("Object not found: ", s3Err) + } else { + // If embedding and hot bucket are different, try to copy from hot bucket + copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey) + if err == nil { + ctxLogger.Info("Got the object from hot bucket object") + return *copyEmbeddingObject, nil + } else { + ctxLogger.WithError(err).Error("Failed to copy from hot bucket object") + } + return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") + } } } ctxLogger.Error("Failed to fetch object: ", err) @@ -455,6 +468,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl return obj, nil } +// download the embedding object from hot bucket and upload to embeddings bucket +func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) { + if c.embeddingBucket == c.S3Config.GetHotBucket() { + return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "") + } + downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) + obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket()) + if err != nil { + return nil, stacktrace.Propagate(err, "failed to download from hot bucket") + } + go func() { + _, err = c.uploadObject(obj, objectKey) + if err != nil { + log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err) + } + }() + + return &obj, nil +} + func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error { if req.Model == "" { return ente.NewBadRequestWithMessage("model is required")