Put auth inside socket

This commit is contained in:
Jon Staab
2025-04-02 14:01:37 -07:00
parent 932a08b7b1
commit 7440a07ffc
7 changed files with 36 additions and 52 deletions
+1 -1
View File
@@ -77,7 +77,7 @@ describe('auth', () => {
it("should handle client AUTH message", () => {
const message: RelayMessage = ["AUTH", { id: "123", kind: CLIENT_AUTH }]
socket.emit(SocketEvent.Enqueue, message)
socket.emit(SocketEvent.Sending, message)
expect(authManager.state.status).toBe(AuthStatus.PendingResponse)
})
+4 -4
View File
@@ -205,7 +205,7 @@ describe('policy', () => {
// Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event)
socket.emit(SocketEvent.Sending, event)
// Should open the socket
expect(openSpy).toHaveBeenCalled()
@@ -222,7 +222,7 @@ describe('policy', () => {
// Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event)
socket.emit(SocketEvent.Sending, event)
// Should not try to open the socket
expect(openSpy).not.toHaveBeenCalled()
@@ -240,7 +240,7 @@ describe('policy', () => {
// Send a message
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.emit(SocketEvent.Enqueue, event)
socket.emit(SocketEvent.Sending, event)
// Should not try to open the socket due to recent error
expect(openSpy).not.toHaveBeenCalled()
@@ -249,7 +249,7 @@ describe('policy', () => {
vi.advanceTimersByTime(31000)
// Send another message
socket.emit(SocketEvent.Enqueue, event)
socket.emit(SocketEvent.Sending, event)
// Now it should try to open
expect(openSpy).toHaveBeenCalled()
+1 -1
View File
@@ -89,7 +89,7 @@ describe("Socket", () => {
describe("send", () => {
it("should queue messages and emit enqueue event", () => {
const enqueueSpy = vi.fn()
socket.on(SocketEvent.Enqueue, enqueueSpy)
socket.on(SocketEvent.Sending, enqueueSpy)
const message: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
socket.send(message)
+1 -9
View File
@@ -52,8 +52,6 @@ export class AuthState extends EventEmitter {
if (isRelayOk(message)) {
const [_, id, ok, details] = message
console.log("ok", message)
if (id === this.request) {
this.details = details
@@ -68,23 +66,19 @@ export class AuthState extends EventEmitter {
if (isRelayAuth(message)) {
const [_, challenge] = message
console.log("relay auth", message)
this.challenge = challenge
this.request = undefined
this.details = undefined
this.setStatus(AuthStatus.Requested)
}
}),
on(socket, SocketEvent.Enqueue, (message: RelayMessage) => {
on(socket, SocketEvent.Sending, (message: RelayMessage) => {
if (isClientAuth(message)) {
console.log("client auth", message)
this.setStatus(AuthStatus.PendingResponse)
}
}),
on(socket, SocketEvent.Status, (status: SocketStatus) => {
if (status === SocketStatus.Closed) {
console.log("closed")
this.challenge = undefined
this.request = undefined
this.details = undefined
@@ -113,8 +107,6 @@ export class AuthState extends EventEmitter {
const template = makeAuthEvent(this.socket.url, this.challenge)
const event = await sign(template)
console.log(event)
if (event) {
this.request = event.id
this.socket.send(["AUTH", event])
+7 -11
View File
@@ -12,7 +12,7 @@ import {
isRelayClosed,
} from "./message.js"
import {Socket, SocketStatus, SocketEvent} from "./socket.js"
import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js"
import {AuthStatus, AuthStateEvent} from "./auth.js"
/**
* Defers sending messages when a challenge has been presented and not answered yet
@@ -21,12 +21,11 @@ import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js"
*/
export const socketPolicyDeferOnAuth = (socket: Socket) => {
const buffer: ClientMessage[] = []
const authState = new AuthState(socket)
const okStatuses = [AuthStatus.None, AuthStatus.Ok]
const unsubscribers = [
// Pause sending certain messages when we're not authenticated
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => {
on(socket, SocketEvent.Sending, (message: ClientMessage) => {
// If we're closing a request, but it never got sent, remove both from the queue
// Otherwise, always send CLOSE
if (isClientClose(message)) {
@@ -47,13 +46,13 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return
// If we're not ok, remove the message and save it for later
if (!okStatuses.includes(authState.status)) {
if (!okStatuses.includes(socket.auth.status)) {
buffer.push(message)
socket._sendQueue.remove(message)
}
}),
// Send buffered messages when we get successful auth
on(authState, AuthStateEvent.Status, (status: AuthStatus) => {
on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
if (okStatuses.includes(status) && buffer.length > 0) {
for (const message of buffer.splice(0)) {
socket.send(message)
@@ -64,7 +63,6 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
return () => {
unsubscribers.forEach(call)
authState.cleanup()
}
}
@@ -142,7 +140,7 @@ export const socketPolicyConnectOnSend = (socket: Socket) => {
lastError = now()
}
}),
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => {
on(socket, SocketEvent.Sending, (message: ClientMessage) => {
// When a new message is sent, make sure the socket is open (unless there was a recent error)
if (socket.status === SocketStatus.Closed && lastError < ago(30)) {
socket.open()
@@ -242,20 +240,18 @@ export type SocketPolicyAuthOptions = {
* @return a socket policy
*/
export const makeSocketPolicyAuth = (options: SocketPolicyAuthOptions) => (socket: Socket) => {
const authState = new AuthState(socket)
const shouldAuth = options.shouldAuth || always(true)
const unsubscribers = [
on(authState, AuthStateEvent.Status, (status: AuthStatus) => {
on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
if (status === AuthStatus.Requested && shouldAuth(socket)) {
authState.authenticate(options.sign)
socket.auth.authenticate(options.sign)
}
}),
]
return () => {
unsubscribers.forEach(call)
authState.cleanup()
}
}
+11 -22
View File
@@ -1,7 +1,6 @@
import {remove} from "@welshman/lib"
import {normalizeRelayUrl} from "@welshman/util"
import {Socket} from "./socket.js"
import {AuthState} from "./auth.js"
import {defaultSocketPolicies} from "./policy.js"
export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
@@ -22,13 +21,8 @@ export type PoolOptions = {
export let poolSingleton: Pool
export type PoolItem = {
socket: Socket
auth: AuthState
}
export class Pool {
_data = new Map<string, PoolItem>()
_data = new Map<string, Socket>()
_subs: PoolSubscription[] = []
static getSingleton() {
@@ -55,25 +49,21 @@ export class Pool {
get(_url: string): Socket {
const url = normalizeRelayUrl(_url)
const item = this._data.get(url)
const socket = this._data.get(url)
if (item) {
return item.socket
if (socket) {
return socket
}
const socket = this.makeSocket(url)
const newSocket = this.makeSocket(url)
this._data.set(url, {socket, auth: new AuthState(socket)})
this._data.set(url, newSocket)
for (const cb of this._subs) {
cb(socket)
cb(newSocket)
}
return socket
}
getAuth(url: string) {
return this._data.get(normalizeRelayUrl(url))?.auth
return newSocket
}
subscribe(cb: PoolSubscription) {
@@ -85,11 +75,10 @@ export class Pool {
}
remove(url: string) {
const item = this._data.get(url)
const socket = this._data.get(url)
if (item) {
item.socket.cleanup()
item.auth.cleanup()
if (socket) {
socket.cleanup()
this._data.delete(url)
}
+11 -4
View File
@@ -2,6 +2,7 @@ import WebSocket from "isomorphic-ws"
import EventEmitter from "events"
import {TaskQueue} from "@welshman/lib"
import {RelayMessage, ClientMessage} from "./message.js"
import {AuthState} from "./auth.js"
export enum SocketStatus {
Open = "socket:status:open",
@@ -16,19 +17,22 @@ export enum SocketEvent {
Error = "socket:event:error",
Status = "socket:event:status",
Send = "socket:event:send",
Enqueue = "socket:event:enqueue",
Sending = "socket:event:sending",
Receive = "socket:event:receive",
Receiving = "socket:event:receiving",
}
export type SocketEvents = {
[SocketEvent.Error]: (error: string, url: string) => void
[SocketEvent.Status]: (status: SocketStatus, url: string) => void
[SocketEvent.Send]: (message: ClientMessage, url: string) => void
[SocketEvent.Enqueue]: (message: ClientMessage, url: string) => void
[SocketEvent.Sending]: (message: ClientMessage, url: string) => void
[SocketEvent.Receive]: (message: RelayMessage, url: string) => void
[SocketEvent.Receiving]: (message: RelayMessage, url: string) => void
}
export class Socket extends EventEmitter {
auth: AuthState
status = SocketStatus.Closed
_ws?: WebSocket
@@ -38,6 +42,8 @@ export class Socket extends EventEmitter {
constructor(readonly url: string) {
super()
this.auth = new AuthState(this)
this._sendQueue = new TaskQueue<ClientMessage>({
batchSize: 50,
processItem: (message: ClientMessage) => {
@@ -84,7 +90,6 @@ export class Socket extends EventEmitter {
this._ws.onclose = () => {
this._ws = undefined
this._sendQueue.stop()
console.log("socket closed", this.url)
this.emit(SocketEvent.Status, SocketStatus.Closed, this.url)
}
@@ -96,6 +101,7 @@ export class Socket extends EventEmitter {
if (Array.isArray(message)) {
this._recvQueue.push(message as RelayMessage)
this.emit(SocketEvent.Receiving, message, this.url)
} else {
this.emit(SocketEvent.Error, "Invalid message received", this.url)
}
@@ -121,6 +127,7 @@ export class Socket extends EventEmitter {
cleanup = () => {
this.close()
this.auth.cleanup()
this._recvQueue.clear()
this._sendQueue.clear()
this.removeAllListeners()
@@ -128,6 +135,6 @@ export class Socket extends EventEmitter {
send = (message: ClientMessage) => {
this._sendQueue.push(message)
this.emit(SocketEvent.Enqueue, message, this.url)
this.emit(SocketEvent.Sending, message, this.url)
}
}