diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 4b38d60b3b..523b4fd769 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -621,6 +621,7 @@ func main() { privateAPI.POST("/emergency-contacts/start-recovery", emergencyHandler.StartRecovery) privateAPI.POST("/emergency-contacts/stop-recovery", emergencyHandler.StopRecovery) privateAPI.POST("/emergency-contacts/reject-recovery", emergencyHandler.RejectRecovery) + privateAPI.POST("/emergency-contacts/approve-recovery", emergencyHandler.RejectRecovery) privateAPI.GET("/emergency-contacts/recovery-info/:id", emergencyHandler.GetRecoveryInfo) privateAPI.POST("/emergency-contacts/init-change-password", emergencyHandler.InitChangePassword) privateAPI.POST("/emergency-contacts/change-password", emergencyHandler.ChangePassword) diff --git a/server/pkg/api/emergency.go b/server/pkg/api/emergency.go index 9571b2c1c6..7fef6a27ca 100644 --- a/server/pkg/api/emergency.go +++ b/server/pkg/api/emergency.go @@ -96,6 +96,20 @@ func (h *EmergencyHandler) RejectRecovery(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) } +func (h *EmergencyHandler) ApproveRecovery(c *gin.Context) { + var request ente.RecoveryIdentifier + if err := c.ShouldBindJSON(&request); err != nil { + handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("failed to validate req param"), err.Error())) + return + } + err := h.Controller.ApproveRecovery(c, auth.GetUserID(c.Request.Header), request) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} + func (h *EmergencyHandler) GetRecoveryInfo(c *gin.Context) { sessionID, err := uuid.Parse(c.Param("id")) if err != nil { diff --git a/server/pkg/controller/emergency/recovery.go b/server/pkg/controller/emergency/recovery.go index 6915b3eedd..4983693048 100644 --- a/server/pkg/controller/emergency/recovery.go +++ b/server/pkg/controller/emergency/recovery.go @@ -13,7 +13,7 @@ func (c *Controller) GetRecoveryInfo(ctx *gin.Context, userID int64, sessionID uuid.UUID, ) (*string, *ente.KeyAttributes, error) { - contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID) + contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID) if err != nil { return nil, nil, err } @@ -30,7 +30,7 @@ func (c *Controller) GetRecoveryInfo(ctx *gin.Context, func (c *Controller) InitChangePassword(ctx *gin.Context, userID int64, request ente.RecoverySrpSetupRequest) (*ente.SetupSRPResponse, error) { sessionID := request.RecoveryID - contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID) + contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (c *Controller) InitChangePassword(ctx *gin.Context, userID int64, request func (c *Controller) ChangePassword(ctx *gin.Context, userID int64, request ente.RecoveryUpdateSRPAndKeysRequest) (*ente.UpdateSRPSetupResponse, error) { sessionID := request.RecoveryID - contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID) + contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (c *Controller) ChangePassword(ctx *gin.Context, userID int64, request ente return resp, nil } -func (c *Controller) validateSessionAndGetContact(ctx *gin.Context, +func (c *Controller) checkRecoveryAndGetContact(ctx *gin.Context, userID int64, sessionID uuid.UUID) (*emergency.ContactRow, error) { recoverRow, err := c.Repo.GetRecoverRowByID(ctx, sessionID) diff --git a/server/pkg/controller/emergency/recovery_contact.go b/server/pkg/controller/emergency/recovery_contact.go index 769e19a179..039050f1be 100644 --- a/server/pkg/controller/emergency/recovery_contact.go +++ b/server/pkg/controller/emergency/recovery_contact.go @@ -53,6 +53,26 @@ func (c *Controller) RejectRecovery(ctx *gin.Context, return nil } +func (c *Controller) ApproveRecovery(ctx *gin.Context, + userID int64, + req ente.RecoveryIdentifier) error { + if req.EmergencyContactID == req.UserID { + return stacktrace.Propagate(ente.NewBadRequestWithMessage("contact and user can not be same"), "") + } + if req.UserID != userID { + return stacktrace.Propagate(ente.ErrPermissionDenied, "only account owner can reject recovery") + } + hasUpdate, err := c.Repo.UpdateRecoveryStatusForID(ctx, req.ID, ente.RecoveryStatusReady) + if !hasUpdate { + log.WithField("userID", userID).WithField("req", req). + Warn("no row updated while rejecting recovery") + } + if err != nil { + return stacktrace.Propagate(err, "") + } + return nil +} + func (c *Controller) StopRecovery(ctx *gin.Context, userID int64, req ente.RecoveryIdentifier) error { diff --git a/server/pkg/repo/emergency/recovery.go b/server/pkg/repo/emergency/recovery.go index 099d56df07..f84478dc60 100644 --- a/server/pkg/repo/emergency/recovery.go +++ b/server/pkg/repo/emergency/recovery.go @@ -2,6 +2,7 @@ package emergency import ( "context" + "database/sql" "fmt" "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/utils/time" @@ -62,9 +63,15 @@ FROM emergency_recovery WHERE (user_id=$1 OR emergency_contact_id=$1) AND statu func (repo *Repository) UpdateRecoveryStatusForID(ctx context.Context, sessionID uuid.UUID, status ente.RecoveryStatus) (bool, error) { validPrevStatus := validPreviousStatus(status) - result, err := repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1 WHERE id=$2 and status = ANY($3)`, status, sessionID, pq.Array(validPrevStatus)) - if err != nil { - return false, stacktrace.Propagate(err, "") + var result sql.Result + var err error + if status == ente.RecoveryStatusReady { + result, err = repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1, wait_till=$2 WHERE id=$3 and status = ANY($4)`, status, time.Microseconds(), sessionID, pq.Array(validPrevStatus)) + } else { + result, err = repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1 WHERE id=$2 and status = ANY($3)`, status, sessionID, pq.Array(validPrevStatus)) + if err != nil { + return false, stacktrace.Propagate(err, "") + } } rows, _ := result.RowsAffected() return rows > 0, nil