Add policy for deferring messages when auth has failed
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import {yieldThread} from "./Tools.js"
|
||||
import {remove, yieldThread} from "./Tools.js"
|
||||
|
||||
export type TaskQueueOptions<Item> = {
|
||||
batchSize: number
|
||||
@@ -17,13 +17,19 @@ export class TaskQueue<Item> {
|
||||
this.process()
|
||||
}
|
||||
|
||||
remove(item: Item) {
|
||||
this.items = remove(item, this.items)
|
||||
}
|
||||
|
||||
async process() {
|
||||
if (this.isProcessing || this.isPaused) {
|
||||
if (this.isProcessing || this.isPaused || this.items.length == 0) {
|
||||
return
|
||||
}
|
||||
|
||||
this.isProcessing = true
|
||||
|
||||
await yieldThread()
|
||||
|
||||
for (const item of this.items.splice(0, this.options.batchSize)) {
|
||||
try {
|
||||
await this.options.processItem(item)
|
||||
@@ -35,8 +41,6 @@ export class TaskQueue<Item> {
|
||||
this.isProcessing = false
|
||||
|
||||
if (this.items.length > 0) {
|
||||
await yieldThread()
|
||||
|
||||
this.process()
|
||||
}
|
||||
}
|
||||
|
||||
+67
-31
@@ -1,8 +1,10 @@
|
||||
import {on, sleep} from "@welshman/lib"
|
||||
import EventEmitter from "events"
|
||||
import {on, call, sleep} from "@welshman/lib"
|
||||
import type {SignedEvent, StampedEvent} from "@welshman/util"
|
||||
import {makeEvent, CLIENT_AUTH} from "@welshman/util"
|
||||
import {isRelayAuth, isRelayOk, RelayMessage} from "./message.js"
|
||||
import {isRelayAuth, isClientAuth, isRelayOk, RelayMessage} from "./message.js"
|
||||
import {Socket, SocketStatus, SocketEventType, SocketUnsubscriber} from "./socket.js"
|
||||
import {TypedEmitter} from "./util.js"
|
||||
|
||||
export const makeAuthEvent = (url: string, challenge: string) =>
|
||||
makeEvent(CLIENT_AUTH, {
|
||||
@@ -27,22 +29,24 @@ export type AuthResult = {
|
||||
reason?: string
|
||||
}
|
||||
|
||||
export type AuthManagerOptions = {
|
||||
sign: (event: StampedEvent) => Promise<SignedEvent>
|
||||
eager?: boolean
|
||||
export enum AuthStateEventType {
|
||||
Status = "auth:state:event:status",
|
||||
}
|
||||
|
||||
export class AuthManager {
|
||||
export type AuthStateEvents = {
|
||||
[AuthStateEventType.Status]: (status: AuthStatus) => void
|
||||
}
|
||||
|
||||
export class AuthState extends (EventEmitter as new () => TypedEmitter<AuthStateEvents>) {
|
||||
challenge: string | undefined
|
||||
request: string | undefined
|
||||
details: string | undefined
|
||||
status = AuthStatus.None
|
||||
_unsubscribers: SocketUnsubscriber[] = []
|
||||
|
||||
constructor(
|
||||
readonly socket: Socket,
|
||||
readonly options: AuthManagerOptions,
|
||||
) {
|
||||
constructor(readonly socket: Socket) {
|
||||
super()
|
||||
|
||||
this._unsubscribers.push(
|
||||
on(socket, SocketEventType.Receive, (message: RelayMessage) => {
|
||||
if (isRelayOk(message)) {
|
||||
@@ -52,9 +56,9 @@ export class AuthManager {
|
||||
this.details = details
|
||||
|
||||
if (ok) {
|
||||
this.status = AuthStatus.Ok
|
||||
this.setStatus(AuthStatus.Ok)
|
||||
} else {
|
||||
this.status = AuthStatus.Forbidden
|
||||
this.setStatus(AuthStatus.Forbidden)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,11 +69,15 @@ export class AuthManager {
|
||||
this.challenge = challenge
|
||||
this.request = undefined
|
||||
this.details = undefined
|
||||
this.status = AuthStatus.Requested
|
||||
this.setStatus(AuthStatus.Requested)
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
if (this.options.eager) {
|
||||
this.respond()
|
||||
}
|
||||
this._unsubscribers.push(
|
||||
on(socket, SocketEventType.Enqueue, (message: RelayMessage) => {
|
||||
if (isClientAuth(message)) {
|
||||
this.setStatus(AuthStatus.PendingResponse)
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -80,12 +88,43 @@ export class AuthManager {
|
||||
this.challenge = undefined
|
||||
this.request = undefined
|
||||
this.details = undefined
|
||||
this.status = AuthStatus.None
|
||||
this.setStatus(AuthStatus.None)
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
setStatus(status: AuthStatus) {
|
||||
this.status = status
|
||||
this.emit(AuthStateEventType.Status, status)
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
this.removeAllListeners()
|
||||
this._unsubscribers.forEach(call)
|
||||
}
|
||||
}
|
||||
|
||||
export type AuthManagerOptions = {
|
||||
sign: (event: StampedEvent) => Promise<SignedEvent>
|
||||
eager?: boolean
|
||||
}
|
||||
|
||||
export class AuthManager {
|
||||
state: AuthState
|
||||
|
||||
constructor(
|
||||
readonly socket: Socket,
|
||||
readonly options: AuthManagerOptions,
|
||||
) {
|
||||
this.state = new AuthState(socket)
|
||||
this.state.on(AuthStateEventType.Status, (status: string) => {
|
||||
if (status === AuthStatus.Requested && options.eager) {
|
||||
this.respond()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async waitFor(condition: () => boolean, timeout = 300) {
|
||||
const start = Date.now()
|
||||
|
||||
@@ -99,14 +138,14 @@ export class AuthManager {
|
||||
}
|
||||
|
||||
async waitForChallenge(timeout = 300) {
|
||||
await this.waitFor(() => Boolean(this.challenge), timeout)
|
||||
await this.waitFor(() => Boolean(this.state.challenge), timeout)
|
||||
}
|
||||
|
||||
async waitForResolution(timeout = 300) {
|
||||
await this.waitFor(
|
||||
() =>
|
||||
[AuthStatus.None, AuthStatus.DeniedSignature, AuthStatus.Forbidden, AuthStatus.Ok].includes(
|
||||
this.status,
|
||||
this.state.status,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
@@ -116,7 +155,7 @@ export class AuthManager {
|
||||
await this.socket.attemptToOpen()
|
||||
await this.waitForChallenge(Math.ceil(timeout / 2))
|
||||
|
||||
if (this.status === AuthStatus.Requested) {
|
||||
if (this.state.status === AuthStatus.Requested) {
|
||||
await this.respond()
|
||||
}
|
||||
|
||||
@@ -124,31 +163,28 @@ export class AuthManager {
|
||||
}
|
||||
|
||||
async respond() {
|
||||
if (!this.challenge) {
|
||||
if (!this.state.challenge) {
|
||||
throw new Error("Attempted to authenticate with no challenge")
|
||||
}
|
||||
|
||||
if (this.status !== AuthStatus.Requested) {
|
||||
throw new Error(`Attempted to authenticate when auth is already ${this.status}`)
|
||||
if (this.state.status !== AuthStatus.Requested) {
|
||||
throw new Error(`Attempted to authenticate when auth is already ${this.state.status}`)
|
||||
}
|
||||
|
||||
this.status = AuthStatus.PendingSignature
|
||||
this.state.setStatus(AuthStatus.PendingSignature)
|
||||
|
||||
const template = makeAuthEvent(this.socket.url, this.challenge)
|
||||
const template = makeAuthEvent(this.socket.url, this.state.challenge)
|
||||
const event = await this.options.sign(template)
|
||||
|
||||
if (event) {
|
||||
this.request = event.id
|
||||
this.state.request = event.id
|
||||
this.socket.send(["AUTH", event])
|
||||
this.status = AuthStatus.PendingResponse
|
||||
} else {
|
||||
this.status = AuthStatus.DeniedSignature
|
||||
this.state.setStatus(AuthStatus.DeniedSignature)
|
||||
}
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
for (const cb of this._unsubscribers) {
|
||||
cb()
|
||||
}
|
||||
this.state.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type {SignedEvent} from "@welshman/util"
|
||||
import type {SignedEvent, Filter} from "@welshman/util"
|
||||
|
||||
// relay -> client
|
||||
|
||||
@@ -53,35 +53,51 @@ export const isRelayOk = (m: RelayMessage): m is RelayOk => m[0] === RelayMessag
|
||||
|
||||
export enum ClientMessageType {
|
||||
Auth = "AUTH",
|
||||
Close = "CLOSE",
|
||||
Event = "EVENT",
|
||||
NegClose = "NEG-CLOSE",
|
||||
NegOpen = "NEG-OPEN",
|
||||
Req = "REQ",
|
||||
}
|
||||
|
||||
export type ClientMessage = any[]
|
||||
|
||||
export type ClientAuthPayload = []
|
||||
export type ClientAuthPayload = [string]
|
||||
|
||||
export type ClientEventPayload = []
|
||||
export type ClientClosePayload = [string]
|
||||
|
||||
export type ClientNegClosePayload = []
|
||||
export type ClientEventPayload = [SignedEvent]
|
||||
|
||||
export type ClientReqPayload = []
|
||||
export type ClientNegClosePayload = [string]
|
||||
|
||||
export type ClientAuth = [ClientMessageType.Req, ...ClientAuthPayload]
|
||||
export type ClientNegOpenPayload = [string, Filter, string]
|
||||
|
||||
export type ClientEvent = [ClientMessageType.Req, ...ClientEventPayload]
|
||||
export type ClientReqPayload = [string, Filter]
|
||||
|
||||
export type ClientNegClose = [ClientMessageType.Req, ...ClientNegClosePayload]
|
||||
export type ClientAuth = [ClientMessageType.Auth, ...ClientAuthPayload]
|
||||
|
||||
export type ClientClose = [ClientMessageType.Close, ...ClientClosePayload]
|
||||
|
||||
export type ClientEvent = [ClientMessageType.Event, ...ClientEventPayload]
|
||||
|
||||
export type ClientNegClose = [ClientMessageType.NegClose, ...ClientNegClosePayload]
|
||||
|
||||
export type ClientNegOpen = [ClientMessageType.NegOpen, ...ClientNegOpenPayload]
|
||||
|
||||
export type ClientReq = [ClientMessageType.Req, ...ClientReqPayload]
|
||||
|
||||
export const isClientAuth = (m: ClientMessage): m is ClientAuth => m[0] === ClientMessageType.Auth
|
||||
|
||||
export const isClientClose = (m: ClientMessage): m is ClientClose =>
|
||||
m[0] === ClientMessageType.Close
|
||||
|
||||
export const isClientEvent = (m: ClientMessage): m is ClientEvent =>
|
||||
m[0] === ClientMessageType.Event
|
||||
|
||||
export const isClientNegClose = (m: ClientMessage): m is ClientNegClose =>
|
||||
m[0] === ClientMessageType.NegClose
|
||||
|
||||
export const isClientNegOpen = (m: ClientMessage): m is ClientNegOpen =>
|
||||
m[0] === ClientMessageType.NegOpen
|
||||
|
||||
export const isClientReq = (m: ClientMessage): m is ClientReq => m[0] === ClientMessageType.Req
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import {on, spec, ago, now} from "@welshman/lib"
|
||||
import {AUTH_JOIN} from "@welshman/util"
|
||||
import {
|
||||
ClientMessage,
|
||||
isClientAuth,
|
||||
isClientClose,
|
||||
isClientEvent,
|
||||
ClientMessageType,
|
||||
} from "./message.js"
|
||||
import {Socket, SocketStatus, SocketEventType} from "./socket.js"
|
||||
import {AuthState, AuthStatus, AuthStateEventType} from "./auth.js"
|
||||
|
||||
// Pause sending messages when the socket isn't open
|
||||
export const socketPolicySendWhenOpen = (socket: Socket) => {
|
||||
const unsubscribe = on(socket, SocketEventType.Status, (newStatus: SocketStatus) => {
|
||||
if (newStatus === SocketStatus.Open) {
|
||||
socket._sendQueue.start()
|
||||
} else {
|
||||
socket._sendQueue.stop()
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}
|
||||
|
||||
export const socketPolicyDeferOnAuth = (socket: Socket) => {
|
||||
const buffer: ClientMessage[] = []
|
||||
const authState = new AuthState(socket)
|
||||
const okStatuses = [AuthStatus.None, AuthStatus.Ok]
|
||||
|
||||
// Pause sending certain messages when we're not authenticated
|
||||
const unsubscribeEnqueue = on(socket, SocketEventType.Enqueue, (message: ClientMessage) => {
|
||||
// If we're closing a request, but it never got sent, remove both from the queue
|
||||
// Otherwise, always send CLOSE
|
||||
if (isClientClose(message)) {
|
||||
const req = buffer.find(spec([ClientMessageType.Req, message[1]]))
|
||||
|
||||
if (req) {
|
||||
socket._sendQueue.remove(req)
|
||||
socket._sendQueue.remove(message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Always allow sending auth
|
||||
if (isClientAuth(message)) return
|
||||
|
||||
// Always allow sending join requests
|
||||
if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return
|
||||
|
||||
// If we're not ok, remove the message and save it for later
|
||||
if (!okStatuses.includes(authState.status)) {
|
||||
buffer.push(message)
|
||||
socket._sendQueue.remove(message)
|
||||
}
|
||||
})
|
||||
|
||||
// Send buffered messages when we get successful auth
|
||||
const unsubscribeAuthStatus = on(authState, AuthStateEventType.Status, (status: AuthStatus) => {
|
||||
if (okStatuses.includes(status) && buffer.length > 0) {
|
||||
for (const message of buffer.splice(0)) {
|
||||
socket.send(message)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
unsubscribeAuthStatus()
|
||||
unsubscribeEnqueue()
|
||||
authState.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
export const socketPolicyConnectOnSend = (socket: Socket) => {
|
||||
let lastError = 0
|
||||
let currentStatus = SocketStatus.Closed
|
||||
|
||||
const unsubscribeOnStatus = on(socket, SocketEventType.Status, (newStatus: SocketStatus) => {
|
||||
// Keep track of the most recent error
|
||||
if (newStatus === SocketStatus.Error) {
|
||||
lastError = now()
|
||||
}
|
||||
|
||||
// Keep track of the current status
|
||||
currentStatus = newStatus
|
||||
})
|
||||
|
||||
const unsubscribeOnSend = on(socket, SocketEventType.Send, (message: ClientMessage) => {
|
||||
// When a new message is sent, make sure the socket is open (unless there was a recent error)
|
||||
if (currentStatus === SocketStatus.Closed && now() - lastError < ago(30)) {
|
||||
socket.open()
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
unsubscribeOnStatus()
|
||||
unsubscribeOnSend()
|
||||
}
|
||||
}
|
||||
|
||||
export const defaultSocketPolicies = [
|
||||
socketPolicySendWhenOpen,
|
||||
socketPolicyDeferOnAuth,
|
||||
socketPolicyConnectOnSend,
|
||||
]
|
||||
@@ -1,6 +1,17 @@
|
||||
import {remove} from "@welshman/lib"
|
||||
import {normalizeRelayUrl} from "@welshman/util"
|
||||
import {Socket, makeSocket} from "./socket.js"
|
||||
import {Socket} from "./socket.js"
|
||||
import {defaultSocketPolicies} from "./policy.js"
|
||||
|
||||
export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
|
||||
const socket = new Socket(url)
|
||||
|
||||
for (const applyPolicy of defaultSocketPolicies) {
|
||||
applyPolicy(socket)
|
||||
}
|
||||
|
||||
return socket
|
||||
}
|
||||
|
||||
export type PoolSubscription = (socket: Socket) => void
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import WebSocket from "isomorphic-ws"
|
||||
import EventEmitter from "events"
|
||||
import {on, now, ago, TaskQueue} from "@welshman/lib"
|
||||
import {TaskQueue} from "@welshman/lib"
|
||||
import {RelayMessage, ClientMessage} from "./message.js"
|
||||
import {TypedEmitter} from "./util.js"
|
||||
|
||||
@@ -17,6 +17,7 @@ export enum SocketEventType {
|
||||
Error = "socket:event:error",
|
||||
Status = "socket:event:status",
|
||||
Send = "socket:event:send",
|
||||
Enqueue = "socket:event:enqueue",
|
||||
Receive = "socket:event:receive",
|
||||
}
|
||||
|
||||
@@ -24,6 +25,7 @@ export type SocketEvents = {
|
||||
[SocketEventType.Error]: (error: string, url: string) => void
|
||||
[SocketEventType.Status]: (status: SocketStatus, url: string) => void
|
||||
[SocketEventType.Send]: (message: ClientMessage, url: string) => void
|
||||
[SocketEventType.Enqueue]: (message: ClientMessage, url: string) => void
|
||||
[SocketEventType.Receive]: (message: RelayMessage, url: string) => void
|
||||
}
|
||||
|
||||
@@ -113,57 +115,6 @@ export class Socket extends (EventEmitter as new () => TypedEmitter<SocketEvents
|
||||
|
||||
send = (message: ClientMessage) => {
|
||||
this._sendQueue.push(message)
|
||||
this.emit(SocketEventType.Enqueue, message, this.url)
|
||||
}
|
||||
}
|
||||
|
||||
export const socketPolicySendWhenOpen = (socket: Socket) => {
|
||||
// Pause sending messages when the socket isn't open
|
||||
const unsubscribe = on(socket, SocketEventType.Status, newStatus => {
|
||||
if (newStatus === SocketStatus.Open) {
|
||||
socket._sendQueue.start()
|
||||
} else {
|
||||
socket._sendQueue.stop()
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}
|
||||
|
||||
export const socketPolicyConnectOnSend = (socket: Socket) => {
|
||||
let lastError = 0
|
||||
let currentStatus = SocketStatus.Closed
|
||||
|
||||
const unsubscribeOnStatus = on(socket, SocketEventType.Status, (newStatus: SocketStatus) => {
|
||||
// Keep track of the most recent error
|
||||
if (newStatus === SocketStatus.Error) {
|
||||
lastError = now()
|
||||
}
|
||||
|
||||
// Keep track of the current status
|
||||
currentStatus = newStatus
|
||||
})
|
||||
|
||||
const unsubscribeOnSend = on(socket, SocketEventType.Send, (message: ClientMessage) => {
|
||||
// When a new message is sent, make sure the socket is open (unless there was a recent error)
|
||||
if (currentStatus === SocketStatus.Closed && now() - lastError < ago(30)) {
|
||||
socket.open()
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
unsubscribeOnStatus()
|
||||
unsubscribeOnSend()
|
||||
}
|
||||
}
|
||||
|
||||
export const defaultSocketPolicies = [socketPolicySendWhenOpen, socketPolicyConnectOnSend]
|
||||
|
||||
export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
|
||||
const socket = new Socket(url)
|
||||
|
||||
for (const applyPolicy of defaultSocketPolicies) {
|
||||
applyPolicy(socket)
|
||||
}
|
||||
|
||||
return socket
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user