Put auth inside socket

This commit is contained in:
Jon Staab
2025-04-02 14:01:37 -07:00
parent 932a08b7b1
commit 7440a07ffc
7 changed files with 36 additions and 52 deletions
+1 -1
View File
@@ -77,7 +77,7 @@ describe('auth', () => {
it("should handle client AUTH message", () => { it("should handle client AUTH message", () => {
const message: RelayMessage = ["AUTH", { id: "123", kind: CLIENT_AUTH }] const message: RelayMessage = ["AUTH", { id: "123", kind: CLIENT_AUTH }]
socket.emit(SocketEvent.Enqueue, message) socket.emit(SocketEvent.Sending, message)
expect(authManager.state.status).toBe(AuthStatus.PendingResponse) expect(authManager.state.status).toBe(AuthStatus.PendingResponse)
}) })
+4 -4
View File
@@ -205,7 +205,7 @@ describe('policy', () => {
// Send a message // Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }] const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event) socket.emit(SocketEvent.Sending, event)
// Should open the socket // Should open the socket
expect(openSpy).toHaveBeenCalled() expect(openSpy).toHaveBeenCalled()
@@ -222,7 +222,7 @@ describe('policy', () => {
// Send a message // Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }] const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event) socket.emit(SocketEvent.Sending, event)
// Should not try to open the socket // Should not try to open the socket
expect(openSpy).not.toHaveBeenCalled() expect(openSpy).not.toHaveBeenCalled()
@@ -240,7 +240,7 @@ describe('policy', () => {
// Send a message // Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }] const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event) socket.emit(SocketEvent.Sending, event)
// Should not try to open the socket due to recent error // Should not try to open the socket due to recent error
expect(openSpy).not.toHaveBeenCalled() expect(openSpy).not.toHaveBeenCalled()
@@ -249,7 +249,7 @@ describe('policy', () => {
vi.advanceTimersByTime(31000) vi.advanceTimersByTime(31000)
// Send another message // Send another message
socket.emit(SocketEvent.Enqueue, event) socket.emit(SocketEvent.Sending, event)
// Now it should try to open // Now it should try to open
expect(openSpy).toHaveBeenCalled() expect(openSpy).toHaveBeenCalled()
+1 -1
View File
@@ -89,7 +89,7 @@ describe("Socket", () => {
describe("send", () => { describe("send", () => {
it("should queue messages and emit enqueue event", () => { it("should queue messages and emit enqueue event", () => {
const enqueueSpy = vi.fn() const enqueueSpy = vi.fn()
socket.on(SocketEvent.Enqueue, enqueueSpy) socket.on(SocketEvent.Sending, enqueueSpy)
const message: ClientMessage = ["EVENT", { id: "123", kind: 1 }] const message: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.send(message) socket.send(message)
+1 -9
View File
@@ -52,8 +52,6 @@ export class AuthState extends EventEmitter {
if (isRelayOk(message)) { if (isRelayOk(message)) {
const [_, id, ok, details] = message const [_, id, ok, details] = message
console.log("ok", message)
if (id === this.request) { if (id === this.request) {
this.details = details this.details = details
@@ -68,23 +66,19 @@ export class AuthState extends EventEmitter {
if (isRelayAuth(message)) { if (isRelayAuth(message)) {
const [_, challenge] = message const [_, challenge] = message
console.log("relay auth", message)
this.challenge = challenge this.challenge = challenge
this.request = undefined this.request = undefined
this.details = undefined this.details = undefined
this.setStatus(AuthStatus.Requested) this.setStatus(AuthStatus.Requested)
} }
}), }),
on(socket, SocketEvent.Enqueue, (message: RelayMessage) => { on(socket, SocketEvent.Sending, (message: RelayMessage) => {
if (isClientAuth(message)) { if (isClientAuth(message)) {
console.log("client auth", message)
this.setStatus(AuthStatus.PendingResponse) this.setStatus(AuthStatus.PendingResponse)
} }
}), }),
on(socket, SocketEvent.Status, (status: SocketStatus) => { on(socket, SocketEvent.Status, (status: SocketStatus) => {
if (status === SocketStatus.Closed) { if (status === SocketStatus.Closed) {
console.log("closed")
this.challenge = undefined this.challenge = undefined
this.request = undefined this.request = undefined
this.details = undefined this.details = undefined
@@ -113,8 +107,6 @@ export class AuthState extends EventEmitter {
const template = makeAuthEvent(this.socket.url, this.challenge) const template = makeAuthEvent(this.socket.url, this.challenge)
const event = await sign(template) const event = await sign(template)
console.log(event)
if (event) { if (event) {
this.request = event.id this.request = event.id
this.socket.send(["AUTH", event]) this.socket.send(["AUTH", event])
+7 -11
View File
@@ -12,7 +12,7 @@ import {
isRelayClosed, isRelayClosed,
} from "./message.js" } from "./message.js"
import {Socket, SocketStatus, SocketEvent} from "./socket.js" import {Socket, SocketStatus, SocketEvent} from "./socket.js"
import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js" import {AuthStatus, AuthStateEvent} from "./auth.js"
/** /**
* Defers sending messages when a challenge has been presented and not answered yet * Defers sending messages when a challenge has been presented and not answered yet
@@ -21,12 +21,11 @@ import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js"
*/ */
export const socketPolicyDeferOnAuth = (socket: Socket) => { export const socketPolicyDeferOnAuth = (socket: Socket) => {
const buffer: ClientMessage[] = [] const buffer: ClientMessage[] = []
const authState = new AuthState(socket)
const okStatuses = [AuthStatus.None, AuthStatus.Ok] const okStatuses = [AuthStatus.None, AuthStatus.Ok]
const unsubscribers = [ const unsubscribers = [
// Pause sending certain messages when we're not authenticated // Pause sending certain messages when we're not authenticated
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => { on(socket, SocketEvent.Sending, (message: ClientMessage) => {
// If we're closing a request, but it never got sent, remove both from the queue // If we're closing a request, but it never got sent, remove both from the queue
// Otherwise, always send CLOSE // Otherwise, always send CLOSE
if (isClientClose(message)) { if (isClientClose(message)) {
@@ -47,13 +46,13 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return
// If we're not ok, remove the message and save it for later // If we're not ok, remove the message and save it for later
if (!okStatuses.includes(authState.status)) { if (!okStatuses.includes(socket.auth.status)) {
buffer.push(message) buffer.push(message)
socket._sendQueue.remove(message) socket._sendQueue.remove(message)
} }
}), }),
// Send buffered messages when we get successful auth // Send buffered messages when we get successful auth
on(authState, AuthStateEvent.Status, (status: AuthStatus) => { on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
if (okStatuses.includes(status) && buffer.length > 0) { if (okStatuses.includes(status) && buffer.length > 0) {
for (const message of buffer.splice(0)) { for (const message of buffer.splice(0)) {
socket.send(message) socket.send(message)
@@ -64,7 +63,6 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
return () => { return () => {
unsubscribers.forEach(call) unsubscribers.forEach(call)
authState.cleanup()
} }
} }
@@ -142,7 +140,7 @@ export const socketPolicyConnectOnSend = (socket: Socket) => {
lastError = now() lastError = now()
} }
}), }),
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => { on(socket, SocketEvent.Sending, (message: ClientMessage) => {
// When a new message is sent, make sure the socket is open (unless there was a recent error) // When a new message is sent, make sure the socket is open (unless there was a recent error)
if (socket.status === SocketStatus.Closed && lastError < ago(30)) { if (socket.status === SocketStatus.Closed && lastError < ago(30)) {
socket.open() socket.open()
@@ -242,20 +240,18 @@ export type SocketPolicyAuthOptions = {
* @return a socket policy * @return a socket policy
*/ */
export const makeSocketPolicyAuth = (options: SocketPolicyAuthOptions) => (socket: Socket) => { export const makeSocketPolicyAuth = (options: SocketPolicyAuthOptions) => (socket: Socket) => {
const authState = new AuthState(socket)
const shouldAuth = options.shouldAuth || always(true) const shouldAuth = options.shouldAuth || always(true)
const unsubscribers = [ const unsubscribers = [
on(authState, AuthStateEvent.Status, (status: AuthStatus) => { on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
if (status === AuthStatus.Requested && shouldAuth(socket)) { if (status === AuthStatus.Requested && shouldAuth(socket)) {
authState.authenticate(options.sign) socket.auth.authenticate(options.sign)
} }
}), }),
] ]
return () => { return () => {
unsubscribers.forEach(call) unsubscribers.forEach(call)
authState.cleanup()
} }
} }
+11 -22
View File
@@ -1,7 +1,6 @@
import {remove} from "@welshman/lib" import {remove} from "@welshman/lib"
import {normalizeRelayUrl} from "@welshman/util" import {normalizeRelayUrl} from "@welshman/util"
import {Socket} from "./socket.js" import {Socket} from "./socket.js"
import {AuthState} from "./auth.js"
import {defaultSocketPolicies} from "./policy.js" import {defaultSocketPolicies} from "./policy.js"
export const makeSocket = (url: string, policies = defaultSocketPolicies) => { export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
@@ -22,13 +21,8 @@ export type PoolOptions = {
export let poolSingleton: Pool export let poolSingleton: Pool
export type PoolItem = {
socket: Socket
auth: AuthState
}
export class Pool { export class Pool {
_data = new Map<string, PoolItem>() _data = new Map<string, Socket>()
_subs: PoolSubscription[] = [] _subs: PoolSubscription[] = []
static getSingleton() { static getSingleton() {
@@ -55,25 +49,21 @@ export class Pool {
get(_url: string): Socket { get(_url: string): Socket {
const url = normalizeRelayUrl(_url) const url = normalizeRelayUrl(_url)
const item = this._data.get(url) const socket = this._data.get(url)
if (item) { if (socket) {
return item.socket return socket
} }
const socket = this.makeSocket(url) const newSocket = this.makeSocket(url)
this._data.set(url, {socket, auth: new AuthState(socket)}) this._data.set(url, newSocket)
for (const cb of this._subs) { for (const cb of this._subs) {
cb(socket) cb(newSocket)
} }
return socket return newSocket
}
getAuth(url: string) {
return this._data.get(normalizeRelayUrl(url))?.auth
} }
subscribe(cb: PoolSubscription) { subscribe(cb: PoolSubscription) {
@@ -85,11 +75,10 @@ export class Pool {
} }
remove(url: string) { remove(url: string) {
const item = this._data.get(url) const socket = this._data.get(url)
if (item) { if (socket) {
item.socket.cleanup() socket.cleanup()
item.auth.cleanup()
this._data.delete(url) this._data.delete(url)
} }
+11 -4
View File
@@ -2,6 +2,7 @@ import WebSocket from "isomorphic-ws"
import EventEmitter from "events" import EventEmitter from "events"
import {TaskQueue} from "@welshman/lib" import {TaskQueue} from "@welshman/lib"
import {RelayMessage, ClientMessage} from "./message.js" import {RelayMessage, ClientMessage} from "./message.js"
import {AuthState} from "./auth.js"
export enum SocketStatus { export enum SocketStatus {
Open = "socket:status:open", Open = "socket:status:open",
@@ -16,19 +17,22 @@ export enum SocketEvent {
Error = "socket:event:error", Error = "socket:event:error",
Status = "socket:event:status", Status = "socket:event:status",
Send = "socket:event:send", Send = "socket:event:send",
Enqueue = "socket:event:enqueue", Sending = "socket:event:sending",
Receive = "socket:event:receive", Receive = "socket:event:receive",
Receiving = "socket:event:receiving",
} }
export type SocketEvents = { export type SocketEvents = {
[SocketEvent.Error]: (error: string, url: string) => void [SocketEvent.Error]: (error: string, url: string) => void
[SocketEvent.Status]: (status: SocketStatus, url: string) => void [SocketEvent.Status]: (status: SocketStatus, url: string) => void
[SocketEvent.Send]: (message: ClientMessage, url: string) => void [SocketEvent.Send]: (message: ClientMessage, url: string) => void
[SocketEvent.Enqueue]: (message: ClientMessage, url: string) => void [SocketEvent.Sending]: (message: ClientMessage, url: string) => void
[SocketEvent.Receive]: (message: RelayMessage, url: string) => void [SocketEvent.Receive]: (message: RelayMessage, url: string) => void
[SocketEvent.Receiving]: (message: RelayMessage, url: string) => void
} }
export class Socket extends EventEmitter { export class Socket extends EventEmitter {
auth: AuthState
status = SocketStatus.Closed status = SocketStatus.Closed
_ws?: WebSocket _ws?: WebSocket
@@ -38,6 +42,8 @@ export class Socket extends EventEmitter {
constructor(readonly url: string) { constructor(readonly url: string) {
super() super()
this.auth = new AuthState(this)
this._sendQueue = new TaskQueue<ClientMessage>({ this._sendQueue = new TaskQueue<ClientMessage>({
batchSize: 50, batchSize: 50,
processItem: (message: ClientMessage) => { processItem: (message: ClientMessage) => {
@@ -84,7 +90,6 @@ export class Socket extends EventEmitter {
this._ws.onclose = () => { this._ws.onclose = () => {
this._ws = undefined this._ws = undefined
this._sendQueue.stop() this._sendQueue.stop()
console.log("socket closed", this.url)
this.emit(SocketEvent.Status, SocketStatus.Closed, this.url) this.emit(SocketEvent.Status, SocketStatus.Closed, this.url)
} }
@@ -96,6 +101,7 @@ export class Socket extends EventEmitter {
if (Array.isArray(message)) { if (Array.isArray(message)) {
this._recvQueue.push(message as RelayMessage) this._recvQueue.push(message as RelayMessage)
this.emit(SocketEvent.Receiving, message, this.url)
} else { } else {
this.emit(SocketEvent.Error, "Invalid message received", this.url) this.emit(SocketEvent.Error, "Invalid message received", this.url)
} }
@@ -121,6 +127,7 @@ export class Socket extends EventEmitter {
cleanup = () => { cleanup = () => {
this.close() this.close()
this.auth.cleanup()
this._recvQueue.clear() this._recvQueue.clear()
this._sendQueue.clear() this._sendQueue.clear()
this.removeAllListeners() this.removeAllListeners()
@@ -128,6 +135,6 @@ export class Socket extends EventEmitter {
send = (message: ClientMessage) => { send = (message: ClientMessage) => {
this._sendQueue.push(message) this._sendQueue.push(message)
this.emit(SocketEvent.Enqueue, message, this.url) this.emit(SocketEvent.Sending, message, this.url)
} }
} }