From 40a4f783f763c487af4fcebf6b2568ad66fe0ff5 Mon Sep 17 00:00:00 2001
From: Neeraj Gupta <254676+ua741@users.noreply.github.com>
Date: Mon, 22 Jul 2024 15:27:39 +0530
Subject: [PATCH] [server] Add API to return indexed files for given model

---
 server/cmd/museum/main.go                     |  1 +
 server/ente/embedding.go                      | 14 +++++++
 server/pkg/api/embedding.go                   | 18 +++++++++
 server/pkg/controller/embedding/controller.go | 15 ++++++++
 server/pkg/repo/embedding/repository.go       | 37 +++++++++++++++++++
 5 files changed, 85 insertions(+)

diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go
index 5f8534b9d3..9258fa9b77 100644
--- a/server/cmd/museum/main.go
+++ b/server/cmd/museum/main.go
@@ -684,6 +684,7 @@ func main() {
 
 	privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
 	privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff)
+	privateAPI.GET("/embeddings/indexed-files", embeddingHandler.GetIndexedFiles)
 	privateAPI.POST("/embeddings/files", embeddingHandler.GetFilesEmbedding)
 	privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll)
 
diff --git a/server/ente/embedding.go b/server/ente/embedding.go
index fabde44a58..4388219b80 100644
--- a/server/ente/embedding.go
+++ b/server/ente/embedding.go
@@ -10,6 +10,12 @@ type Embedding struct {
 	Size *int64
 }
 
+// IndexedFile pairs the ID of a file that has an embedding with the time that embedding was last updated.
+type IndexedFile struct {
+	FileID    int64 `json:"fileID"`
+	UpdatedAt int64 `json:"updatedAt"`
+}
+
 type InsertOrUpdateEmbeddingRequest struct {
 	FileID int64 `json:"fileID" binding:"required"`
 	Model string `json:"model" binding:"required"`
@@ -25,6 +31,14 @@ type GetEmbeddingDiffRequest struct {
 	Limit int16 `form:"limit" binding:"required"`
 }
 
+// GetIndexedFiles is the query for listing files indexed for Model whose
+// embeddings were updated after SinceTime, optionally capped at Limit rows.
+type GetIndexedFiles struct {
+	Model     Model  `form:"model"`
+	SinceTime *int64 `form:"sinceTime" binding:"required"`
+	Limit     *int64 `form:"limit"`
+}
+
 type GetFilesEmbeddingRequest struct {
 	Model Model `form:"model" binding:"required"`
 	FileIDs []int64 `form:"fileIDs" binding:"required"`
diff --git a/server/pkg/api/embedding.go b/server/pkg/api/embedding.go
index 5b06503318..d95158dd8d 100644
--- a/server/pkg/api/embedding.go
+++ b/server/pkg/api/embedding.go
@@ -50,6 +50,24 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) {
 	})
 }
 
+// GetIndexedFiles returns the fileIDs (with update times) that have been indexed or updated for the given user
+func (h *EmbeddingHandler) GetIndexedFiles(c *gin.Context) {
+	var request ente.GetIndexedFiles
+	if err := c.ShouldBindQuery(&request); err != nil {
+		handler.Error(c,
+			stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
+		return
+	}
+	indexedFiles, err := h.Controller.GetIndexedFiles(c, request)
+	if err != nil {
+		handler.Error(c, stacktrace.Propagate(err, ""))
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"diff": indexedFiles,
+	})
+}
+
 // GetFilesEmbedding returns the embeddings for the files
 func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) {
 	var request ente.GetFilesEmbeddingRequest
diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go
index 6f3de3ca78..aadb938038 100644
--- a/server/pkg/controller/embedding/controller.go
+++ b/server/pkg/controller/embedding/controller.go
@@ -125,6 +125,21 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
 	return &embedding, nil
 }
 
+// GetIndexedFiles returns the requesting user's files that have an embedding
+// for req.Model updated after req.SinceTime, limited to req.Limit when set.
+func (c *Controller) GetIndexedFiles(ctx *gin.Context, req ente.GetIndexedFiles) ([]ente.IndexedFile, error) {
+	userID := auth.GetUserID(ctx.Request.Header)
+	updateSince := int64(0)
+	if req.SinceTime != nil {
+		updateSince = *req.SinceTime
+	}
+	indexedFiles, err := c.Repo.GetIndexedFiles(ctx, userID, req.Model, updateSince, req.Limit)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	return indexedFiles, nil
+}
+
 func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
 	userID := auth.GetUserID(ctx.Request.Header)
 
diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go
index 5cfbd35c57..234683a5a8 100644
--- a/server/pkg/repo/embedding/repository.go
+++ b/server/pkg/repo/embedding/repository.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"github.com/ente-io/museum/ente"
 	"github.com/ente-io/stacktrace"
+	"github.com/gin-gonic/gin"
 	"github.com/lib/pq"
 	"github.com/sirupsen/logrus"
 )
@@ -176,6 +177,42 @@ func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Mode
 	return nil
 }
 
+// GetIndexedFiles returns the file IDs and update times of the embeddings owned
+// by id for the given model that were updated after since. Rows are ordered by
+// updated_at (then file_id) so that a caller paging with limit and the last
+// seen updated_at gets a deterministic window. A nil limit returns all rows.
+func (r *Repository) GetIndexedFiles(ctx *gin.Context, id int64, model ente.Model, since int64, limit *int64) ([]ente.IndexedFile, error) {
+	query := `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 ORDER BY updated_at, file_id`
+	args := []interface{}{id, model, since}
+	if limit != nil {
+		query += ` LIMIT $4`
+		args = append(args, *limit)
+	}
+	rows, err := r.DB.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	defer func() {
+		if err := rows.Close(); err != nil {
+			logrus.Error(err)
+		}
+	}()
+	result := make([]ente.IndexedFile, 0)
+	for rows.Next() {
+		var meta ente.IndexedFile
+		if err := rows.Scan(&meta.FileID, &meta.UpdatedAt); err != nil {
+			return nil, stacktrace.Propagate(err, "")
+		}
+		result = append(result, meta)
+	}
+	// rows.Next returning false can mean an iteration error rather than the
+	// end of the result set; surface it instead of returning a partial list.
+	if err := rows.Err(); err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	return result, nil
+}
+
 func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
 	defer func() {
 		if err := rows.Close(); err != nil {