[server] Add API to return indexed files for given model

This commit is contained in:
Neeraj Gupta
2024-07-22 15:27:39 +05:30
parent 1972239bb0
commit 40a4f783f7
5 changed files with 74 additions and 0 deletions

View File

@@ -684,6 +684,7 @@ func main() {
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff)
privateAPI.GET("/embeddings/indexed-files", embeddingHandler.GetIndexedFiles)
privateAPI.POST("/embeddings/files", embeddingHandler.GetFilesEmbedding)
privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll)

View File

@@ -10,6 +10,12 @@ type Embedding struct {
Size *int64
}
// IndexedFile ...
type IndexedFile struct {
FileID int64 `json:"fileID"`
UpdatedAt int64 `json:"updatedAt"`
}
type InsertOrUpdateEmbeddingRequest struct {
FileID int64 `json:"fileID" binding:"required"`
Model string `json:"model" binding:"required"`
@@ -25,6 +31,12 @@ type GetEmbeddingDiffRequest struct {
Limit int16 `form:"limit" binding:"required"`
}
type GetIndexedFiles struct {
Model Model `form:"model"`
SinceTime *int64 `form:"sinceTime" binding:"required"`
Limit *int64 `form:"limit"`
}
type GetFilesEmbeddingRequest struct {
Model Model `form:"model" binding:"required"`
FileIDs []int64 `form:"fileIDs" binding:"required"`

View File

@@ -50,6 +50,24 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) {
})
}
// GetIndexedFiles returns the fileIDs that has been indexed or updated for given user
func (h *EmbeddingHandler) GetIndexedFiles(c *gin.Context) {
var request ente.GetIndexedFiles
if err := c.ShouldBindQuery(&request); err != nil {
handler.Error(c,
stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
return
}
embeddings, err := h.Controller.GetIndexedFiles(c, request)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{
"diff": embeddings,
})
}
// GetFilesEmbedding returns the embeddings for the files
func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) {
var request ente.GetFilesEmbeddingRequest

View File

@@ -125,6 +125,19 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
return &embedding, nil
}
func (c *Controller) GetIndexedFiles(ctx *gin.Context, req ente.GetIndexedFiles) ([]ente.IndexedFile, error) {
userID := auth.GetUserID(ctx.Request.Header)
updateSince := int64(0)
if req.SinceTime != nil {
updateSince = *req.SinceTime
}
indexedFiles, err := c.Repo.GetIndexedFiles(ctx, userID, req.Model, updateSince, req.Limit)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return indexedFiles, nil
}
func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) ([]ente.Embedding, error) {
userID := auth.GetUserID(ctx.Request.Header)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/ente-io/museum/ente"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
"github.com/sirupsen/logrus"
)
@@ -176,6 +177,35 @@ func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Mode
return nil
}
func (r *Repository) GetIndexedFiles(ctx *gin.Context, id int64, model ente.Model, since int64, limit *int64) ([]ente.IndexedFile, error) {
var rows *sql.Rows
var err error
if limit == nil {
rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3`, id, model, since)
} else {
rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 LIMIT $4`, id, model, since, *limit)
}
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
defer func() {
if err := rows.Close(); err != nil {
logrus.Error(err)
}
}()
result := make([]ente.IndexedFile, 0)
for rows.Next() {
var meta ente.IndexedFile
err := rows.Scan(&meta.FileID, &meta.UpdatedAt)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
result = append(result, meta)
}
return result, nil
}
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
defer func() {
if err := rows.Close(); err != nil {