Add support for approving recovery

This commit is contained in:
Neeraj Gupta
2024-12-10 13:33:18 +05:30
parent c5c77ab706
commit 1222a063e8
5 changed files with 49 additions and 7 deletions

View File

@@ -621,6 +621,7 @@ func main() {
privateAPI.POST("/emergency-contacts/start-recovery", emergencyHandler.StartRecovery)
privateAPI.POST("/emergency-contacts/stop-recovery", emergencyHandler.StopRecovery)
privateAPI.POST("/emergency-contacts/reject-recovery", emergencyHandler.RejectRecovery)
privateAPI.POST("/emergency-contacts/approve-recovery", emergencyHandler.RejectRecovery)
privateAPI.GET("/emergency-contacts/recovery-info/:id", emergencyHandler.GetRecoveryInfo)
privateAPI.POST("/emergency-contacts/init-change-password", emergencyHandler.InitChangePassword)
privateAPI.POST("/emergency-contacts/change-password", emergencyHandler.ChangePassword)

View File

@@ -96,6 +96,20 @@ func (h *EmergencyHandler) RejectRecovery(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
}
func (h *EmergencyHandler) ApproveRecovery(c *gin.Context) {
var request ente.RecoveryIdentifier
if err := c.ShouldBindJSON(&request); err != nil {
handler.Error(c, stacktrace.Propagate(ente.NewBadRequestWithMessage("failed to validate req param"), err.Error()))
return
}
err := h.Controller.ApproveRecovery(c, auth.GetUserID(c.Request.Header), request)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, gin.H{})
}
func (h *EmergencyHandler) GetRecoveryInfo(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {

View File

@@ -13,7 +13,7 @@ func (c *Controller) GetRecoveryInfo(ctx *gin.Context,
userID int64,
sessionID uuid.UUID,
) (*string, *ente.KeyAttributes, error) {
contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID)
contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID)
if err != nil {
return nil, nil, err
}
@@ -30,7 +30,7 @@ func (c *Controller) GetRecoveryInfo(ctx *gin.Context,
func (c *Controller) InitChangePassword(ctx *gin.Context, userID int64, request ente.RecoverySrpSetupRequest) (*ente.SetupSRPResponse, error) {
sessionID := request.RecoveryID
contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID)
contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID)
if err != nil {
return nil, err
}
@@ -43,7 +43,7 @@ func (c *Controller) InitChangePassword(ctx *gin.Context, userID int64, request
func (c *Controller) ChangePassword(ctx *gin.Context, userID int64, request ente.RecoveryUpdateSRPAndKeysRequest) (*ente.UpdateSRPSetupResponse, error) {
sessionID := request.RecoveryID
contact, err := c.validateSessionAndGetContact(ctx, userID, sessionID)
contact, err := c.checkRecoveryAndGetContact(ctx, userID, sessionID)
if err != nil {
return nil, err
}
@@ -64,7 +64,7 @@ func (c *Controller) ChangePassword(ctx *gin.Context, userID int64, request ente
return resp, nil
}
func (c *Controller) validateSessionAndGetContact(ctx *gin.Context,
func (c *Controller) checkRecoveryAndGetContact(ctx *gin.Context,
userID int64,
sessionID uuid.UUID) (*emergency.ContactRow, error) {
recoverRow, err := c.Repo.GetRecoverRowByID(ctx, sessionID)

View File

@@ -53,6 +53,26 @@ func (c *Controller) RejectRecovery(ctx *gin.Context,
return nil
}
func (c *Controller) ApproveRecovery(ctx *gin.Context,
userID int64,
req ente.RecoveryIdentifier) error {
if req.EmergencyContactID == req.UserID {
return stacktrace.Propagate(ente.NewBadRequestWithMessage("contact and user can not be same"), "")
}
if req.UserID != userID {
return stacktrace.Propagate(ente.ErrPermissionDenied, "only account owner can reject recovery")
}
hasUpdate, err := c.Repo.UpdateRecoveryStatusForID(ctx, req.ID, ente.RecoveryStatusReady)
if !hasUpdate {
log.WithField("userID", userID).WithField("req", req).
Warn("no row updated while rejecting recovery")
}
if err != nil {
return stacktrace.Propagate(err, "")
}
return nil
}
func (c *Controller) StopRecovery(ctx *gin.Context,
userID int64,
req ente.RecoveryIdentifier) error {

View File

@@ -2,6 +2,7 @@ package emergency
import (
"context"
"database/sql"
"fmt"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/time"
@@ -62,9 +63,15 @@ FROM emergency_recovery WHERE (user_id=$1 OR emergency_contact_id=$1) AND statu
func (repo *Repository) UpdateRecoveryStatusForID(ctx context.Context, sessionID uuid.UUID, status ente.RecoveryStatus) (bool, error) {
validPrevStatus := validPreviousStatus(status)
result, err := repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1 WHERE id=$2 and status = ANY($3)`, status, sessionID, pq.Array(validPrevStatus))
if err != nil {
return false, stacktrace.Propagate(err, "")
var result sql.Result
var err error
if status == ente.RecoveryStatusReady {
result, err = repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1, wait_till=$2 WHERE id=$3 and status = ANY($4)`, status, time.Microseconds(), sessionID, pq.Array(validPrevStatus))
} else {
result, err = repo.DB.ExecContext(ctx, `UPDATE emergency_recovery SET status=$1 WHERE id=$2 and status = ANY($3)`, status, sessionID, pq.Array(validPrevStatus))
if err != nil {
return false, stacktrace.Propagate(err, "")
}
}
rows, _ := result.RowsAffected()
return rows > 0, nil