From f0236acf8f4678345c01a6d4e6a46d60b859f4d3 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 19 Mar 2024 06:15:55 +0530 Subject: [PATCH] [server] Minor bug fixes in embedding/files API --- server/pkg/api/embedding.go | 2 +- server/pkg/repo/embedding/repository.go | 3 ++- server/pkg/utils/array/array.go | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/server/pkg/api/embedding.go b/server/pkg/api/embedding.go index 4b072b0b70..5b06503318 100644 --- a/server/pkg/api/embedding.go +++ b/server/pkg/api/embedding.go @@ -53,7 +53,7 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) { // GetFilesEmbedding returns the embeddings for the files func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) { var request ente.GetFilesEmbeddingRequest - if err := c.ShouldBindQuery(&request); err != nil { + if err := c.ShouldBindJSON(&request); err != nil { handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) return diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index 90e8a82642..c90dd8743e 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/lib/pq" "github.com/ente-io/museum/ente" "github.com/ente-io/stacktrace" @@ -58,7 +59,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) { rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at FROM embeddings - WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, fileIDs) + WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs)) if err != nil { return nil, stacktrace.Propagate(err, "") } diff --git a/server/pkg/utils/array/array.go b/server/pkg/utils/array/array.go index 46d1e31132..0c2d25d90f 100644 --- a/server/pkg/utils/array/array.go +++ b/server/pkg/utils/array/array.go @@ -71,7 +71,7 @@ func FindMissingElementsInSecondList(sourceList []int64, targetList []int64) []i targetSet[item] = struct{}{} } - var missingElements []int64 + var missingElements = make([]int64, 0) for _, item := range sourceList { if _, found := targetSet[item]; !found { missingElements = append(missingElements, item)