From 4bbe1ae0d237d4d405fb5606df252c3d64f5f2af Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 24 Feb 2025 14:49:21 +0530 Subject: [PATCH] [server] Remove embeddings handler --- server/cmd/museum/main.go | 7 - server/pkg/api/embedding.go | 95 ---- server/pkg/controller/embedding/controller.go | 425 ------------------ server/pkg/controller/embedding/delete.go | 13 - 4 files changed, 540 deletions(-) delete mode 100644 server/pkg/api/embedding.go diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index ded9c46475..9bc1beef63 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -755,13 +755,6 @@ func main() { privateAPI.POST("/push/token", pushHandler.AddToken) embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName) - embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController} - - privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate) - privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff) - privateAPI.GET("/embeddings/indexed-files", embeddingHandler.GetIndexedFiles) - privateAPI.POST("/embeddings/files", embeddingHandler.GetFilesEmbedding) - privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll) offerHandler := &api.OfferHandler{Controller: offerController} publicAPI.GET("/offers/black-friday", offerHandler.GetBlackFridayOffers) diff --git a/server/pkg/api/embedding.go b/server/pkg/api/embedding.go deleted file mode 100644 index d95158dd8d..0000000000 --- a/server/pkg/api/embedding.go +++ /dev/null @@ -1,95 +0,0 @@ -package api - -import ( - "fmt" - "net/http" - - "github.com/ente-io/museum/ente" - "github.com/ente-io/museum/pkg/controller/embedding" - "github.com/ente-io/museum/pkg/utils/handler" - - "github.com/ente-io/stacktrace" - "github.com/gin-gonic/gin" -) - -type EmbeddingHandler struct { - Controller *embedding.Controller -} - -// InsertOrUpdate handler for inserting or updating embedding -func (h *EmbeddingHandler) InsertOrUpdate(c *gin.Context) { - var request ente.InsertOrUpdateEmbeddingRequest - if err := c.ShouldBindJSON(&request); err != nil { - handler.Error(c, - stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) - return - } - embedding, err := h.Controller.InsertOrUpdate(c, request) - if err != nil { - handler.Error(c, stacktrace.Propagate(err, "")) - return - } - c.JSON(http.StatusOK, embedding) -} - -// GetDiff handler for getting diff of embedding -func (h *EmbeddingHandler) GetDiff(c *gin.Context) { - var request ente.GetEmbeddingDiffRequest - if err := c.ShouldBindQuery(&request); err != nil { - handler.Error(c, - stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) - return - } - embeddings, err := h.Controller.GetDiff(c, request) - if err != nil { - handler.Error(c, stacktrace.Propagate(err, "")) - return - } - c.JSON(http.StatusOK, gin.H{ - "diff": embeddings, - }) -} - -// GetIndexedFiles returns the fileIDs that has been indexed or updated for given user -func (h *EmbeddingHandler) GetIndexedFiles(c *gin.Context) { - var request ente.GetIndexedFiles - if err := c.ShouldBindQuery(&request); err != nil { - handler.Error(c, - stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) - return - } - embeddings, err := h.Controller.GetIndexedFiles(c, request) - if err != nil { - handler.Error(c, stacktrace.Propagate(err, "")) - return - } - c.JSON(http.StatusOK, gin.H{ - "diff": embeddings, - }) -} - -// GetFilesEmbedding returns the embeddings for the files -func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) { - var request ente.GetFilesEmbeddingRequest - if err := c.ShouldBindJSON(&request); err != nil { - handler.Error(c, - stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) - return - } - resp, err := h.Controller.GetFilesEmbedding(c, request) - if err != nil { - handler.Error(c, stacktrace.Propagate(err, "")) - return - } - c.JSON(http.StatusOK, resp) -} - -// DeleteAll handler for deleting all embeddings for the user -func (h *EmbeddingHandler) DeleteAll(c *gin.Context) { - err := h.Controller.DeleteAll(c) - if err != nil { - handler.Error(c, stacktrace.Propagate(err, "")) - return - } - c.Status(http.StatusOK) -} diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 190f21d40a..d8cb4e62c0 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -1,32 +1,15 @@ package embedding import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/ente-io/museum/pkg/utils/array" "strconv" - "strings" - "sync" gTime "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" - "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/controller" "github.com/ente-io/museum/pkg/controller/access" "github.com/ente-io/museum/pkg/repo" "github.com/ente-io/museum/pkg/repo/embedding" - "github.com/ente-io/museum/pkg/utils/auth" - "github.com/ente-io/museum/pkg/utils/network" "github.com/ente-io/museum/pkg/utils/s3config" - "github.com/ente-io/stacktrace" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) const ( @@ -35,16 +18,6 @@ const ( embeddingFetchTimeout = 10 * gTime.Second ) -// _fetchConfig is the configuration for the fetching objects from S3 -type _fetchConfig struct { - RetryCount int - InitialTimeout gTime.Duration - MaxTimeout gTime.Duration -} - -var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second} -var _b2FetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 15 * gTime.Second, MaxTimeout: 30 * gTime.Second} - type Controller struct { Repo *embedding.Repository AccessCtrl access.Controller @@ -82,404 +55,6 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup } } -func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) { - userID := auth.GetUserID(ctx.Request.Header) - - err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{ - ActorUserId: userID, - FileIDs: []int64{req.FileID}, - }) - - if err != nil { - return nil, stacktrace.Propagate(err, "User does not own file") - } - - count, err := c.CollectionRepo.GetCollectionCount(req.FileID) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - if count < 1 { - return nil, stacktrace.Propagate(ente.ErrNotFound, "") - } - version := 1 - if req.Version != nil { - version = *req.Version - } - - obj := ente.EmbeddingObject{ - Version: version, - EncryptedEmbedding: req.EncryptedEmbedding, - DecryptionHeader: req.DecryptionHeader, - Client: network.GetClientInfo(ctx), - } - size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter) - if uploadErr != nil { - log.Error(uploadErr) - return nil, stacktrace.Propagate(uploadErr, "") - } - embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter) - embedding.Version = &version - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return &embedding, nil -} - -func (c *Controller) GetIndexedFiles(ctx *gin.Context, req ente.GetIndexedFiles) ([]ente.IndexedFile, error) { - userID := auth.GetUserID(ctx.Request.Header) - updateSince := int64(0) - if req.SinceTime != nil { - updateSince = *req.SinceTime - } - indexedFiles, err := c.Repo.GetIndexedFiles(ctx, userID, req.Model, updateSince, req.Limit) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return indexedFiles, nil -} - -func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) { - userID := auth.GetUserID(ctx.Request.Header) - - if req.Model == "" { - req.Model = ente.GgmlClip - } - - embeddings, err := c.Repo.GetDiff(ctx, userID, req.Model, *req.SinceTime, req.Limit) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - - // Collect object keys for embeddings with missing data - var objectKeys []string - for i := range embeddings { - if embeddings[i].EncryptedEmbedding == "" { - objectKey := c.getObjectKey(userID, embeddings[i].FileID, embeddings[i].Model) - objectKeys = append(objectKeys, objectKey) - } - } - - // Fetch missing embeddings in parallel - if len(objectKeys) > 0 { - embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - - // Populate missing data in embeddings from fetched objects - for i, obj := range embeddingObjects { - for j := range embeddings { - if embeddings[j].EncryptedEmbedding == "" && c.getObjectKey(userID, embeddings[j].FileID, embeddings[j].Model) == objectKeys[i] { - embeddings[j].EncryptedEmbedding = obj.EncryptedEmbedding - embeddings[j].DecryptionHeader = obj.DecryptionHeader - } - } - } - } - - return embeddings, nil -} - -func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) { - userID := auth.GetUserID(ctx.Request.Header) - if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil { - return nil, stacktrace.Propagate(err, "") - } - - userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - - embeddingsWithData := make([]ente.Embedding, 0) - noEmbeddingFileIds := make([]int64, 0) - dbFileIds := make([]int64, 0) - // fileIDs that were indexed, but they don't contain any embedding information - for i := range userFileEmbeddings { - dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID) - if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize { - noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID) - } else { - embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i]) - } - } - pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds) - errFileIds := make([]int64, 0) - - // Fetch missing userFileEmbeddings in parallel - embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - fetchedEmbeddings := make([]ente.Embedding, 0) - - // Populate missing data in userFileEmbeddings from fetched objects - for _, obj := range embeddingObjects { - if obj.err != nil { - errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID) - } else { - fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{ - FileID: obj.dbEmbeddingRow.FileID, - Model: obj.dbEmbeddingRow.Model, - EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding, - DecryptionHeader: obj.embeddingObject.DecryptionHeader, - UpdatedAt: obj.dbEmbeddingRow.UpdatedAt, - Version: obj.dbEmbeddingRow.Version, - }) - } - } - - return &ente.GetFilesEmbeddingResponse{ - Embeddings: fetchedEmbeddings, - PendingIndexFileIDs: pendingIndexFileIds, - ErrFileIDs: errFileIds, - NoEmbeddingFileIDs: noEmbeddingFileIds, - }, nil -} - -func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string { - return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json" -} - func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string { return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/" } - -// Get userId, model and fileID from the object key -func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) { - split := strings.Split(objectKey, "/") - userID, _ = strconv.ParseInt(split[0], 10, 64) - fileID, _ = strconv.ParseInt(split[2], 10, 64) - model = strings.Split(split[3], ".")[0] - return userID, model, fileID -} - -// uploadObject uploads the embedding object to the object store and returns the object size -func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) { - embeddingObj, _ := json.Marshal(obj) - s3Client := c.S3Config.GetS3Client(dc) - s3Bucket := c.S3Config.GetBucket(dc) - uploader := s3manager.NewUploaderWithClient(&s3Client) - up := s3manager.UploadInput{ - Bucket: s3Bucket, - Key: &key, - Body: bytes.NewReader(embeddingObj), - } - result, err := uploader.Upload(&up) - if err != nil { - log.Error(err) - return -1, stacktrace.Propagate(err, "") - } - - log.Infof("Uploaded to bucket %s", result.Location) - return len(embeddingObj), nil -} - -var globalDiffFetchSemaphore = make(chan struct{}, 300) - -var globalFileFetchSemaphore = make(chan struct{}, 400) - -func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) { - var wg sync.WaitGroup - var errs []error - embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys)) - for i, objectKey := range objectKeys { - wg.Add(1) - globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore - go func(i int, objectKey string) { - defer wg.Done() - defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore - - obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc) - if err != nil { - errs = append(errs, err) - log.Error("error fetching embedding object: "+objectKey, err) - } else { - embeddingObjects[i] = obj - } - }(i, objectKey) - } - - wg.Wait() - - if len(errs) > 0 { - return nil, stacktrace.Propagate(errors.New("failed to fetch some objects"), "") - } - - return embeddingObjects, nil -} - -type embeddingObjectResult struct { - embeddingObject ente.EmbeddingObject - dbEmbeddingRow ente.Embedding - err error -} - -func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) { - var wg sync.WaitGroup - embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows)) - - for i, dbEmbeddingRow := range dbEmbeddingRows { - wg.Add(1) - globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore - go func(i int, dbEmbeddingRow ente.Embedding) { - defer wg.Done() - defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore - objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) - obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc) - if err != nil { - log.Error("error fetching embedding object: "+objectKey, err) - embeddingObjects[i] = embeddingObjectResult{ - err: err, - dbEmbeddingRow: dbEmbeddingRow, - } - - } else { - embeddingObjects[i] = embeddingObjectResult{ - embeddingObject: obj, - dbEmbeddingRow: dbEmbeddingRow, - } - } - }(i, dbEmbeddingRow) - } - wg.Wait() - return embeddingObjects, nil -} - -func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) { - opt := _defaultFetchConfig - if dc == c.S3Config.GetHotBackblazeDC() { - opt = _b2FetchConfig - } - ctxLogger := log.WithField("objectKey", objectKey).WithField("dc", dc) - totalAttempts := opt.RetryCount + 1 - timeout := opt.InitialTimeout - for i := 0; i < totalAttempts; i++ { - if i > 0 { - timeout = timeout * 2 - if timeout > opt.MaxTimeout { - timeout = opt.MaxTimeout - } - } - fetchCtx, cancel := context.WithTimeout(ctx, timeout) - select { - case <-ctx.Done(): - cancel() - return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "") - default: - obj, err := c.downloadObject(fetchCtx, objectKey, dc) - cancel() // Ensure cancel is called to release resources - if err == nil { - if i > 0 { - ctxLogger.Infof("Fetched object after %d attempts", i) - } - return obj, nil - } - // Check if the error is due to context timeout or cancellation - if err == nil && fetchCtx.Err() != nil { - ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err()) - } else { - // check if the error is due to object not found - if s3Err, ok := err.(awserr.RequestFailure); ok { - if s3Err.Code() == s3.ErrCodeNoSuchKey { - var srcDc, destDc string - destDc = c.S3Config.GetDerivedStorageDataCenter() - // todo:(neeraj) Refactor this later to get available the DC from the DB instead of - // querying the DB. This will help in case of multiple DCs and avoid querying the DB - // for each object. - // For initial migration, as we know that original DC was b2, and if the embedding is not found - // in the new derived DC, we can try to fetch it from the B2 DC. - if c.derivedStorageDataCenter != c.S3Config.GetHotBackblazeDC() { - // embeddings ideally should ideally be in the default hot bucket b2 - srcDc = c.S3Config.GetHotBackblazeDC() - } else { - _, modelName, fileID := c.getEmbeddingObjectDetails(objectKey) - activeDcs, err := c.Repo.GetOtherDCsForFileAndModel(context.Background(), fileID, modelName, c.derivedStorageDataCenter) - if err != nil { - return ente.EmbeddingObject{}, stacktrace.Propagate(err, "failed to get other dc") - } - if len(activeDcs) > 0 { - srcDc = activeDcs[0] - } else { - ctxLogger.Error("Object not found in any dc ", s3Err) - return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") - } - } - copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, srcDc, destDc) - if err == nil { - ctxLogger.Infof("Got object from dc %s", srcDc) - return *copyEmbeddingObject, nil - } else { - ctxLogger.WithError(err).Errorf("Failed to get object from fallback dc %s", srcDc) - } - return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") - } - } - ctxLogger.Error("Failed to fetch object: ", err) - } - } - } - return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "") -} - -func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) { - var obj ente.EmbeddingObject - buff := &aws.WriteAtBuffer{} - bucket := c.S3Config.GetBucket(dc) - downloader := c.downloadManagerCache[dc] - _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ - Bucket: bucket, - Key: &objectKey, - }) - if err != nil { - return obj, err - } - err = json.Unmarshal(buff.Bytes(), &obj) - if err != nil { - return obj, stacktrace.Propagate(err, "unmarshal failed") - } - return obj, nil -} - -// download the embedding object from hot bucket and upload to embeddings bucket -func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) { - if srcDC == destDC { - return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "") - } - obj, err := c.downloadObject(ctx, objectKey, srcDC) - if err != nil { - return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC)) - } - go func() { - userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey) - size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter) - if uploadErr != nil { - log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr) - } - updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC) - if updateDcErr != nil { - log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr) - return - } - }() - return &obj, nil -} - -func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error { - if req.Model == "" { - return ente.NewBadRequestWithMessage("model is required") - } - if len(req.FileIDs) == 0 { - return ente.NewBadRequestWithMessage("fileIDs are required") - } - if len(req.FileIDs) > 200 { - return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200") - } - if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{ - ActorUserId: userID, - FileIDs: req.FileIDs, - }); err != nil { - return stacktrace.Propagate(err, "User does not own some file(s)") - } - return nil -} diff --git a/server/pkg/controller/embedding/delete.go b/server/pkg/controller/embedding/delete.go index 91a70963fe..134a5192dc 100644 --- a/server/pkg/controller/embedding/delete.go +++ b/server/pkg/controller/embedding/delete.go @@ -4,24 +4,11 @@ import ( "context" "fmt" "github.com/ente-io/museum/pkg/repo" - "github.com/ente-io/museum/pkg/utils/auth" "github.com/ente-io/museum/pkg/utils/time" - "github.com/ente-io/stacktrace" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" "strconv" ) -func (c *Controller) DeleteAll(ctx *gin.Context) error { - userID := auth.GetUserID(ctx.Request.Header) - - err := c.Repo.DeleteAll(ctx, userID) - if err != nil { - return stacktrace.Propagate(err, "") - } - return nil -} - // CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store func (c *Controller) CleanupDeletedEmbeddings() { log.Info("Cleaning up deleted embeddings")