[server] Reject/Stop active recovery when contact is removed
This commit is contained in:
@@ -33,7 +33,6 @@ const (
|
||||
UserInvitedContact ContactState = "INVITED"
|
||||
UserRevokedContact ContactState = "REVOKED"
|
||||
ContactAccepted ContactState = "ACCEPTED"
|
||||
ContactDeleted ContactState = "DELETED"
|
||||
ContactLeft ContactState = "CONTACT_LEFT"
|
||||
ContactDenied ContactState = "CONTACT_DENIED"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package emergency
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/controller/user"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
@@ -23,6 +24,33 @@ func (c *Controller) UpdateContact(ctx *gin.Context,
|
||||
if err := validateUpdateReq(userID, req); err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
if req.State == ente.ContactDenied || req.State == ente.ContactLeft || req.State == ente.UserRevokedContact {
|
||||
activeSessions, sessionErr := c.Repo.GetActiveSessions(ctx, req.UserID, req.EmergencyContactID)
|
||||
if sessionErr != nil {
|
||||
return stacktrace.Propagate(sessionErr, "")
|
||||
}
|
||||
for _, session := range activeSessions {
|
||||
if req.State == ente.UserRevokedContact {
|
||||
rejErr := c.RejectRecovery(ctx, userID, ente.RecoveryIdentifier{
|
||||
ID: session.ID,
|
||||
UserID: session.UserID,
|
||||
EmergencyContactID: session.EmergencyContactID,
|
||||
})
|
||||
if rejErr != nil {
|
||||
return stacktrace.Propagate(rejErr, "failed to reject recovery")
|
||||
}
|
||||
} else {
|
||||
stopErr := c.StopRecovery(ctx, userID, ente.RecoveryIdentifier{
|
||||
ID: session.ID,
|
||||
UserID: session.UserID,
|
||||
EmergencyContactID: session.EmergencyContactID,
|
||||
})
|
||||
if stopErr != nil {
|
||||
return stacktrace.Propagate(stopErr, "failed to stop recovery")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
hasUpdate, err := c.Repo.UpdateState(ctx, req.UserID, req.EmergencyContactID, req.State)
|
||||
if !hasUpdate {
|
||||
log.WithField("userID", userID).WithField("req", req).
|
||||
@@ -30,12 +58,6 @@ func (c *Controller) UpdateContact(ctx *gin.Context,
|
||||
} else {
|
||||
go c.sendContactNotification(ctx, req.UserID, req.EmergencyContactID, req.State)
|
||||
}
|
||||
recoverStatus := getNextRecoveryStatusFromContactState(req.State)
|
||||
if recoverStatus != nil {
|
||||
if err := c.Repo.UpdateRecoveryStatus(ctx, req.UserID, req.EmergencyContactID, *recoverStatus); err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
@@ -66,20 +88,3 @@ func validateUpdateReq(userID int64, req ente.UpdateContact) error {
|
||||
return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("Can not update state to %s", req.State)), "")
|
||||
}
|
||||
}
|
||||
|
||||
// When a user contact state is update, we need to update the recovery status for any ongoing recovery
|
||||
func getNextRecoveryStatusFromContactState(state ente.ContactState) *ente.RecoveryStatus {
|
||||
switch state {
|
||||
case ente.ContactAccepted:
|
||||
return nil
|
||||
case ente.UserInvitedContact:
|
||||
return nil
|
||||
case ente.ContactLeft:
|
||||
return ente.RecoveryStatusStopped.Ptr()
|
||||
case ente.ContactDenied:
|
||||
return ente.RecoveryStatusStopped.Ptr()
|
||||
case ente.UserRevokedContact:
|
||||
return ente.RecoveryStatusRejected.Ptr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
@@ -65,6 +66,24 @@ FROM emergency_recovery WHERE (user_id=$1 OR emergency_contact_id=$1) AND statu
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (repo *Repository) GetActiveSessions(ctx *gin.Context, userID int64, emergencyContactID int64) ([]*RecoverRow, error) {
|
||||
rows, err := repo.DB.QueryContext(ctx, `SELECT id, user_id, emergency_contact_id, status, wait_till, next_reminder_at, created_at
|
||||
FROM emergency_recovery WHERE user_id=$1 and emergency_contact_id=$2 AND status= ANY($3)`, userID, emergencyContactID, pq.Array([]ente.RecoveryStatus{ente.RecoveryStatusWaiting, ente.RecoveryStatusReady}))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
defer rows.Close()
|
||||
var sessions []*RecoverRow
|
||||
for rows.Next() {
|
||||
var row RecoverRow
|
||||
if err := rows.Scan(&row.ID, &row.UserID, &row.EmergencyContactID, &row.Status, &row.WaitTill, &row.NextReminderAt, &row.CreatedAt); err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
sessions = append(sessions, &row)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (repo *Repository) UpdateRecoveryStatusForID(ctx context.Context, sessionID uuid.UUID, status ente.RecoveryStatus) (bool, error) {
|
||||
validPrevStatus := validPreviousStatus(status)
|
||||
var result sql.Result
|
||||
|
||||
@@ -118,8 +118,7 @@ func getValidPreviousState(cs ente.ContactState) []ente.ContactState {
|
||||
return []ente.ContactState{ente.UserInvitedContact}
|
||||
case ente.UserRevokedContact:
|
||||
return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted}
|
||||
case ente.ContactDeleted:
|
||||
return []ente.ContactState{ente.UserInvitedContact, ente.ContactAccepted}
|
||||
|
||||
}
|
||||
panic("invalid state")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user