From c648127ff876cae7a3750ec5cfadfa87bdcacddf Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Thu, 12 Dec 2024 14:53:08 +0530 Subject: [PATCH] [server] Reject/Stop active recovery when contact is removed --- server/ente/emergency.go | 1 - server/pkg/controller/emergency/controller.go | 51 ++++++++++--------- server/pkg/repo/emergency/recovery.go | 19 +++++++ server/pkg/repo/emergency/repository.go | 3 +- 4 files changed, 48 insertions(+), 26 deletions(-) diff --git a/server/ente/emergency.go b/server/ente/emergency.go index 57207e97dc..0dcab7ab17 100644 --- a/server/ente/emergency.go +++ b/server/ente/emergency.go @@ -33,7 +33,6 @@ const ( UserInvitedContact ContactState = "INVITED" UserRevokedContact ContactState = "REVOKED" ContactAccepted ContactState = "ACCEPTED" - ContactDeleted ContactState = "DELETED" ContactLeft ContactState = "CONTACT_LEFT" ContactDenied ContactState = "CONTACT_DENIED" ) diff --git a/server/pkg/controller/emergency/controller.go b/server/pkg/controller/emergency/controller.go index ed7c5cb396..1cf399f86e 100644 --- a/server/pkg/controller/emergency/controller.go +++ b/server/pkg/controller/emergency/controller.go @@ -2,6 +2,7 @@ package emergency import ( "fmt" + "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/controller/user" "github.com/ente-io/museum/pkg/repo" @@ -23,6 +24,33 @@ func (c *Controller) UpdateContact(ctx *gin.Context, if err := validateUpdateReq(userID, req); err != nil { return stacktrace.Propagate(err, "") } + if req.State == ente.ContactDenied || req.State == ente.ContactLeft || req.State == ente.UserRevokedContact { + activeSessions, sessionErr := c.Repo.GetActiveSessions(ctx, req.UserID, req.EmergencyContactID) + if sessionErr != nil { + return stacktrace.Propagate(sessionErr, "") + } + for _, session := range activeSessions { + if req.State == ente.UserRevokedContact { + rejErr := c.RejectRecovery(ctx, userID, ente.RecoveryIdentifier{ + ID: session.ID, + UserID: session.UserID, + EmergencyContactID: session.EmergencyContactID, + }) + if rejErr != nil { + return stacktrace.Propagate(rejErr, "failed to reject recovery") + } + } else { + stopErr := c.StopRecovery(ctx, userID, ente.RecoveryIdentifier{ + ID: session.ID, + UserID: session.UserID, + EmergencyContactID: session.EmergencyContactID, + }) + if stopErr != nil { + return stacktrace.Propagate(stopErr, "failed to stop recovery") + } + } + } + } hasUpdate, err := c.Repo.UpdateState(ctx, req.UserID, req.EmergencyContactID, req.State) if !hasUpdate { log.WithField("userID", userID).WithField("req", req). @@ -30,12 +58,6 @@ func (c *Controller) UpdateContact(ctx *gin.Context, } else { go c.sendContactNotification(ctx, req.UserID, req.EmergencyContactID, req.State) } - recoverStatus := getNextRecoveryStatusFromContactState(req.State) - if recoverStatus != nil { - if err := c.Repo.UpdateRecoveryStatus(ctx, req.UserID, req.EmergencyContactID, *recoverStatus); err != nil { - return stacktrace.Propagate(err, "") - } - } if err != nil { return stacktrace.Propagate(err, "") } @@ -66,20 +88,3 @@ func validateUpdateReq(userID int64, req ente.UpdateContact) error { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("Can not update state to %s", req.State)), "") } } - -// When a user contact state is update, we need to update the recovery status for any ongoing recovery -func getNextRecoveryStatusFromContactState(state ente.ContactState) *ente.RecoveryStatus { - switch state { - case ente.ContactAccepted: - return nil - case ente.UserInvitedContact: - return nil - case ente.ContactLeft: - return ente.RecoveryStatusStopped.Ptr() - case ente.ContactDenied: - return ente.RecoveryStatusStopped.Ptr() - case ente.UserRevokedContact: - return ente.RecoveryStatusRejected.Ptr() - } - return nil -} diff --git a/server/pkg/repo/emergency/recovery.go b/server/pkg/repo/emergency/recovery.go index 9d5a363b0f..dc41c3d5ff 100644 --- a/server/pkg/repo/emergency/recovery.go +++ b/server/pkg/repo/emergency/recovery.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" @@ -65,6 +66,24 @@ FROM emergency_recovery WHERE (user_id=$1 OR emergency_contact_id=$1) AND statu return sessions, nil } +func (repo *Repository) GetActiveSessions(ctx *gin.Context, userID int64, emergencyContactID int64) ([]*RecoverRow, error) { + rows, err := repo.DB.QueryContext(ctx, `SELECT id, user_id, emergency_contact_id, status, wait_till, next_reminder_at, created_at +FROM emergency_recovery WHERE user_id=$1 and emergency_contact_id=$2 AND status= ANY($3)`, userID, emergencyContactID, pq.Array([]ente.RecoveryStatus{ente.RecoveryStatusWaiting, ente.RecoveryStatusReady})) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + defer rows.Close() + var sessions []*RecoverRow + for rows.Next() { + var row RecoverRow + if err := rows.Scan(&row.ID, &row.UserID, &row.EmergencyContactID, &row.Status, &row.WaitTill, &row.NextReminderAt, &row.CreatedAt); err != nil { + return nil, stacktrace.Propagate(err, "") + } + sessions = append(sessions, &row) + } + return sessions, nil +} + func (repo *Repository) UpdateRecoveryStatusForID(ctx context.Context, sessionID uuid.UUID, status ente.RecoveryStatus) (bool, error) { validPrevStatus := validPreviousStatus(status) var result sql.Result diff --git a/server/pkg/repo/emergency/repository.go b/server/pkg/repo/emergency/repository.go index eb64a3d275..5bf4eba9c3 100644 --- a/server/pkg/repo/emergency/repository.go +++ b/server/pkg/repo/emergency/repository.go @@ -118,8 +118,7 @@ func getValidPreviousState(cs ente.ContactState) []ente.ContactState { return []ente.ContactState{ente.UserInvitedContact} case ente.UserRevokedContact: return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted} - case ente.ContactDeleted: - return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted} + } panic("invalid state") }