diff --git a/packages/lib/src/TaskQueue.ts b/packages/lib/src/TaskQueue.ts index 10d936d..ab708a8 100644 --- a/packages/lib/src/TaskQueue.ts +++ b/packages/lib/src/TaskQueue.ts @@ -7,19 +7,21 @@ export type TaskQueueOptions = { export class TaskQueue { items: Item[] = [] + isPaused = false isProcessing = false constructor(readonly options: TaskQueueOptions) {} push(item: Item) { this.items.push(item) - - if (!this.isProcessing) { - this.processBatch() - } + this.process() } - async processBatch() { + async process() { + if (this.isProcessing || this.isPaused) { + return + } + this.isProcessing = true for (const item of this.items.splice(0, this.options.batchSize)) { @@ -30,15 +32,24 @@ export class TaskQueue { } } + this.isProcessing = false + if (this.items.length > 0) { await yieldThread() - this.processBatch() - } else { - this.isProcessing = false + this.process() } } + stop() { + this.isPaused = true + } + + start() { + this.isPaused = false + this.process() + } + clear() { this.items = [] } diff --git a/packages/net2/src/auth.ts b/packages/net2/src/auth.ts index db98012..30c85c8 100644 --- a/packages/net2/src/auth.ts +++ b/packages/net2/src/auth.ts @@ -1,6 +1,7 @@ -import type {SignedEvent} from "@welshman/util" +import {sleep} from "@welshman/lib" +import type {SignedEvent, StampedEvent} from "@welshman/util" import {makeEvent, CLIENT_AUTH} from "@welshman/util" -import type {ISocket} from "./socket.js" +import {Socket, SocketStatus, SocketUnsubscriber} from "./socket.js" export const makeAuthEvent = (url: string, challenge: string) => makeEvent(CLIENT_AUTH, { @@ -10,16 +11,139 @@ export const makeAuthEvent = (url: string, challenge: string) => ], }) +export enum AuthStatus { + None = "auth:status:none", + Requested = "auth:status:requested", + PendingSignature = "auth:status:pending_signature", + DeniedSignature = "auth:status:denied_signature", + PendingResponse = "auth:status:pending_response", + Forbidden = "auth:status:forbidden", + Ok = "auth:status:ok", +} + export type AuthResult = { ok: boolean reason?: string } -export const authenticate = (socket: ISocket, event: SignedEvent) => - new Promise(resolve => { - socket.send(["AUTH", event]) +export type AuthManagerOptions = { + sign: (event: StampedEvent) => Promise + eager?: boolean +} - socket.onOk(([id, ok = false, reason = ""]) => { - if (id === event.id) resolve({ok, reason}) - }) - }) +export class AuthManager { + challenge: string | undefined + request: string | undefined + message: string | undefined + status = AuthStatus.None + _unsubscribers: SocketUnsubscriber[] = [] + + constructor( + readonly socket: Socket, + readonly options: AuthManagerOptions, + ) { + this._unsubscribers.push( + socket.onOk(([id, ok, message]) => { + if (id === this.request) { + this.message = message + + if (ok) { + this.status = AuthStatus.Ok + } else { + this.status = AuthStatus.Forbidden + } + } + }), + ) + + this._unsubscribers.push( + socket.onAuth(([challenge]) => { + this.challenge = challenge + this.request = undefined + this.message = undefined + this.status = AuthStatus.Requested + + if (this.options.eager) { + this.respond() + } + }), + ) + + this._unsubscribers.push( + socket.onStatus(status => { + if (status === SocketStatus.Closed) { + this.challenge = undefined + this.request = undefined + this.message = undefined + this.status = AuthStatus.None + } + }), + ) + } + + async waitFor(condition: () => boolean, timeout = 300) { + const start = Date.now() + + while (Date.now() - timeout <= start) { + if (condition()) { + break + } + + await sleep(Math.min(100, Math.ceil(timeout / 3))) + } + } + + async waitForChallenge(timeout = 300) { + await this.waitFor(() => Boolean(this.challenge), timeout) + } + + async waitForResolution(timeout = 300) { + await this.waitFor( + () => + [AuthStatus.None, AuthStatus.DeniedSignature, AuthStatus.Forbidden, AuthStatus.Ok].includes( + this.status, + ), + timeout, + ) + } + + async attempt(timeout = 300) { + await this.socket.attemptToOpen() + await this.waitForChallenge(Math.ceil(timeout / 2)) + + if (this.status === AuthStatus.Requested) { + await this.respond() + } + + await this.waitForResolution(Math.ceil(timeout / 2)) + } + + async respond() { + if (!this.challenge) { + throw new Error("Attempted to authenticate with no challenge") + } + + if (this.status !== AuthStatus.Requested) { + throw new Error(`Attempted to authenticate when auth is already ${this.status}`) + } + + this.status = AuthStatus.PendingSignature + + const template = makeAuthEvent(this.socket.url, this.challenge) + const event = await this.options.sign(template) + + if (event) { + this.request = event.id + this.socket.send(["AUTH", event]) + this.status = AuthStatus.PendingResponse + } else { + this.status = AuthStatus.DeniedSignature + } + } + + cleanup() { + for (const cb of this._unsubscribers) { + cb() + } + } +} diff --git a/packages/net2/src/message.ts b/packages/net2/src/message.ts index edeca36..4b1d7e5 100644 --- a/packages/net2/src/message.ts +++ b/packages/net2/src/message.ts @@ -1,5 +1,7 @@ import type {SignedEvent} from "@welshman/util" +// relay -> client + export enum RelayMessageType { Auth = "AUTH", Event = "EVENT", @@ -36,3 +38,7 @@ export const isRelayEoseMessage = (m: RelayMessage): m is RelayEoseMessage => export const isRelayOkMessage = (m: RelayMessage): m is RelayOkMessage => m[0] === RelayMessageType.Ok + +// client -> relay + +export type ClientMessage = any[] diff --git a/packages/net2/src/pool.ts b/packages/net2/src/pool.ts new file mode 100644 index 0000000..61dfbf7 --- /dev/null +++ b/packages/net2/src/pool.ts @@ -0,0 +1,71 @@ +import {remove} from "@welshman/lib" +import {normalizeRelayUrl} from "@welshman/util" +import {ISocket, makeSocket} from "./socket.js" + +export type PoolSubscription = (socket: ISocket) => void + +export type PoolOptions = { + makeSocket?: (url: string) => ISocket +} + +export class Pool { + _data = new Map() + _subs: PoolSubscription[] = [] + + constructor(readonly options: PoolOptions) {} + + has(url: string) { + return this._data.has(url) + } + + makeSocket(url: string) { + if (this.options.makeSocket) { + return this.options.makeSocket(url) + } + + return makeSocket(url) + } + + get(_url: string): ISocket { + const url = normalizeRelayUrl(_url) + const oldSocket = this._data.get(url) + + if (oldSocket) { + return oldSocket + } + + const socket = this.makeSocket(url) + + this._data.set(url, socket) + + for (const cb of this._subs) { + cb(socket) + } + + return socket + } + + subscribe(cb: PoolSubscription) { + this._subs.push(cb) + + return () => { + this._subs = remove(cb, this._subs) + } + } + + remove(url: string) { + const socket = this._data.get(url) + + if (socket) { + socket.cleanup() + + this._data.delete(url) + } + } + + clear() { + for (const url of this._data.keys()) { + this.remove(url) + } + } +} diff --git a/packages/net2/src/socket.ts b/packages/net2/src/socket.ts index 2c2520d..c69dc20 100644 --- a/packages/net2/src/socket.ts +++ b/packages/net2/src/socket.ts @@ -1,11 +1,12 @@ import WebSocket from "isomorphic-ws" -import {remove, TaskQueue} from "@welshman/lib" +import {remove, now, ago, TaskQueue} from "@welshman/lib" import type { RelayMessage, RelayAuthPayload, RelayEosePayload, RelayEventPayload, RelayOkPayload, + ClientMessage, } from "./message.js" import { isRelayAuthMessage, @@ -70,7 +71,9 @@ export const isSocketStatusEvent = (event: SocketEvent): event is SocketStatusEv export const isSocketMessageEvent = (event: SocketEvent): event is SocketMessageEvent => event.type === SocketEventType.Message -export type SocketSubscriber = (event: SocketEvent) => void +export type SocketSendSubscriber = (message: ClientMessage) => void + +export type SocketRecvSubscriber = (event: SocketEvent) => void export type SocketUnsubscriber = () => void @@ -78,8 +81,9 @@ export interface ISocket { open(): void close(): void cleanup(): void - send(...message: any[]): void - subscribe(cb: SocketSubscriber): SocketUnsubscriber + send(message: ClientMessage): void + onSend(cb: SocketSendSubscriber): SocketUnsubscriber + subscribe(cb: SocketRecvSubscriber): SocketUnsubscriber onError(cb: (error: string) => void): SocketUnsubscriber onStatus(cb: (status: SocketStatus) => void): SocketUnsubscriber onMessage(cb: (message: RelayMessage) => void): SocketUnsubscriber @@ -92,14 +96,27 @@ export interface ISocket { export class Socket implements ISocket { _ws?: WebSocket - _subs: SocketSubscriber[] = [] - _queue: TaskQueue + _sendSubs: SocketSendSubscriber[] = [] + _recvSubs: SocketRecvSubscriber[] = [] + _sendQueue: TaskQueue + _recvQueue: TaskQueue constructor(readonly url: string) { - this._queue = new TaskQueue({ + this._sendQueue = new TaskQueue({ + batchSize: 50, + processItem: (message: ClientMessage) => { + this._ws?.send(JSON.stringify(message)) + + for (const cb of this._sendSubs) { + cb(message) + } + }, + }) + + this._recvQueue = new TaskQueue({ batchSize: 50, processItem: (event: SocketEvent) => { - for (const cb of this._subs) { + for (const cb of this._recvSubs) { cb(event) } }, @@ -107,12 +124,26 @@ export class Socket implements ISocket { } open = () => { + if (this._ws) { + throw new Error("Attempted to open a websocket that has not been closed") + } + try { this._ws = new WebSocket(this.url) - this._queue.push(makeSocketStatusEvent(SocketStatus.Opening)) - this._ws.onopen = () => this._queue.push(makeSocketStatusEvent(SocketStatus.Open)) - this._ws.onerror = () => this._queue.push(makeSocketStatusEvent(SocketStatus.Error)) - this._ws.onclose = () => this._queue.push(makeSocketStatusEvent(SocketStatus.Closed)) + this._recvQueue.push(makeSocketStatusEvent(SocketStatus.Opening)) + + this._ws.onopen = () => this._recvQueue.push(makeSocketStatusEvent(SocketStatus.Open)) + + this._ws.onerror = () => { + this._recvQueue.push(makeSocketStatusEvent(SocketStatus.Error)) + this._ws = undefined + } + + this._ws.onclose = () => { + this._recvQueue.push(makeSocketStatusEvent(SocketStatus.Closed)) + this._ws = undefined + } + this._ws.onmessage = (event: any) => { const data = event.data as string @@ -120,16 +151,22 @@ export class Socket implements ISocket { const message = JSON.parse(data) if (Array.isArray(message)) { - this._queue.push(makeSocketMessageEvent(message as RelayMessage)) + this._recvQueue.push(makeSocketMessageEvent(message as RelayMessage)) } else { - this._queue.push(makeSocketErrorEvent("Invalid message received")) + this._recvQueue.push(makeSocketErrorEvent("Invalid message received")) } } catch (e) { - this._queue.push(makeSocketErrorEvent("Invalid message received")) + this._recvQueue.push(makeSocketErrorEvent("Invalid message received")) } } } catch (e) { - this._queue.push(makeSocketStatusEvent(SocketStatus.Invalid)) + this._recvQueue.push(makeSocketStatusEvent(SocketStatus.Invalid)) + } + } + + attemptToOpen = () => { + if (!this._ws) { + this.open() } } @@ -140,19 +177,29 @@ export class Socket implements ISocket { cleanup = () => { this.close() - this._subs = [] - this._queue.clear() + this._recvSubs = [] + this._recvQueue.clear() + this._sendSubs = [] + this._sendQueue.clear() } - send = (...message: any[]) => { - this._ws?.send(JSON.stringify(message)) + send = (message: ClientMessage) => { + this._sendQueue.push(message) } - subscribe = (cb: SocketSubscriber) => { - this._subs.push(cb) + onSend = (cb: SocketSendSubscriber) => { + this._sendSubs.push(cb) return () => { - this._subs = remove(cb, this._subs) + this._sendSubs = remove(cb, this._sendSubs) + } + } + + subscribe = (cb: SocketRecvSubscriber) => { + this._recvSubs.push(cb) + + return () => { + this._recvSubs = remove(cb, this._recvSubs) } } @@ -224,3 +271,55 @@ export class Socket implements ISocket { }) } } + +export const socketPolicySendWhenOpen = (socket: Socket) => { + // Pause sending messages when the socket isn't open + const unsubscribe = socket.onStatus(newStatus => { + if (newStatus === SocketStatus.Open) { + socket._sendQueue.start() + } else { + socket._sendQueue.stop() + } + }) + + return unsubscribe +} + +export const socketPolicyConnectOnSend = (socket: Socket) => { + let lastError = 0 + let currentStatus = SocketStatus.Closed + + const unsubscribeOnStatus = socket.onStatus(newStatus => { + // Keep track of the most recent error + if (newStatus === SocketStatus.Error) { + lastError = now() + } + + // Keep track of the current status + currentStatus = newStatus + }) + + const unsubscribeOnSend = socket.onSend(message => { + // When a new message is sent, make sure the socket is open (unless there was a recent error) + if (currentStatus === SocketStatus.Closed && now() - lastError < ago(30)) { + socket.open() + } + }) + + return () => { + unsubscribeOnStatus() + unsubscribeOnSend() + } +} + +export const defaultSocketPolicies = [socketPolicySendWhenOpen, socketPolicyConnectOnSend] + +export const makeSocket = (url: string, policies = defaultSocketPolicies) => { + const socket = new Socket(url) + + for (const applyPolicy of defaultSocketPolicies) { + applyPolicy(socket) + } + + return socket +}