[server] Reject/Stop active recovery when contact is removed

This commit is contained in:
Neeraj Gupta
2024-12-12 14:53:08 +05:30
parent cbe105020b
commit c648127ff8
4 changed files with 48 additions and 26 deletions

View File

@@ -33,7 +33,6 @@ const (
UserInvitedContact ContactState = "INVITED"
UserRevokedContact ContactState = "REVOKED"
ContactAccepted ContactState = "ACCEPTED"
ContactDeleted ContactState = "DELETED"
ContactLeft ContactState = "CONTACT_LEFT"
ContactDenied ContactState = "CONTACT_DENIED"
)

View File

@@ -2,6 +2,7 @@ package emergency
import (
"fmt"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/controller/user"
"github.com/ente-io/museum/pkg/repo"
@@ -23,6 +24,33 @@ func (c *Controller) UpdateContact(ctx *gin.Context,
if err := validateUpdateReq(userID, req); err != nil {
return stacktrace.Propagate(err, "")
}
if req.State == ente.ContactDenied || req.State == ente.ContactLeft || req.State == ente.UserRevokedContact {
activeSessions, sessionErr := c.Repo.GetActiveSessions(ctx, req.UserID, req.EmergencyContactID)
if sessionErr != nil {
return stacktrace.Propagate(sessionErr, "")
}
for _, session := range activeSessions {
if req.State == ente.UserRevokedContact {
rejErr := c.RejectRecovery(ctx, userID, ente.RecoveryIdentifier{
ID: session.ID,
UserID: session.UserID,
EmergencyContactID: session.EmergencyContactID,
})
if rejErr != nil {
return stacktrace.Propagate(rejErr, "failed to reject recovery")
}
} else {
stopErr := c.StopRecovery(ctx, userID, ente.RecoveryIdentifier{
ID: session.ID,
UserID: session.UserID,
EmergencyContactID: session.EmergencyContactID,
})
if stopErr != nil {
return stacktrace.Propagate(stopErr, "failed to stop recovery")
}
}
}
}
hasUpdate, err := c.Repo.UpdateState(ctx, req.UserID, req.EmergencyContactID, req.State)
if !hasUpdate {
log.WithField("userID", userID).WithField("req", req).
@@ -30,12 +58,6 @@ func (c *Controller) UpdateContact(ctx *gin.Context,
} else {
go c.sendContactNotification(ctx, req.UserID, req.EmergencyContactID, req.State)
}
recoverStatus := getNextRecoveryStatusFromContactState(req.State)
if recoverStatus != nil {
if err := c.Repo.UpdateRecoveryStatus(ctx, req.UserID, req.EmergencyContactID, *recoverStatus); err != nil {
return stacktrace.Propagate(err, "")
}
}
if err != nil {
return stacktrace.Propagate(err, "")
}
@@ -66,20 +88,3 @@ func validateUpdateReq(userID int64, req ente.UpdateContact) error {
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("Can not update state to %s", req.State)), "")
}
}
// When a user contact state is update, we need to update the recovery status for any ongoing recovery
func getNextRecoveryStatusFromContactState(state ente.ContactState) *ente.RecoveryStatus {
switch state {
case ente.ContactAccepted:
return nil
case ente.UserInvitedContact:
return nil
case ente.ContactLeft:
return ente.RecoveryStatusStopped.Ptr()
case ente.ContactDenied:
return ente.RecoveryStatusStopped.Ptr()
case ente.UserRevokedContact:
return ente.RecoveryStatusRejected.Ptr()
}
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
@@ -65,6 +66,24 @@ FROM emergency_recovery WHERE (user_id=$1 OR emergency_contact_id=$1) AND statu
return sessions, nil
}
func (repo *Repository) GetActiveSessions(ctx *gin.Context, userID int64, emergencyContactID int64) ([]*RecoverRow, error) {
rows, err := repo.DB.QueryContext(ctx, `SELECT id, user_id, emergency_contact_id, status, wait_till, next_reminder_at, created_at
FROM emergency_recovery WHERE user_id=$1 and emergency_contact_id=$2 AND status= ANY($3)`, userID, emergencyContactID, pq.Array([]ente.RecoveryStatus{ente.RecoveryStatusWaiting, ente.RecoveryStatusReady}))
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
defer rows.Close()
var sessions []*RecoverRow
for rows.Next() {
var row RecoverRow
if err := rows.Scan(&row.ID, &row.UserID, &row.EmergencyContactID, &row.Status, &row.WaitTill, &row.NextReminderAt, &row.CreatedAt); err != nil {
return nil, stacktrace.Propagate(err, "")
}
sessions = append(sessions, &row)
}
return sessions, nil
}
func (repo *Repository) UpdateRecoveryStatusForID(ctx context.Context, sessionID uuid.UUID, status ente.RecoveryStatus) (bool, error) {
validPrevStatus := validPreviousStatus(status)
var result sql.Result

View File

@@ -118,8 +118,7 @@ func getValidPreviousState(cs ente.ContactState) []ente.ContactState {
return []ente.ContactState{ente.UserInvitedContact}
case ente.UserRevokedContact:
return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted}
case ente.ContactDeleted:
return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted}
}
panic("invalid state")
}