Put auth inside socket
This commit is contained in:
@@ -77,7 +77,7 @@ describe('auth', () => {
|
|||||||
|
|
||||||
it("should handle client AUTH message", () => {
|
it("should handle client AUTH message", () => {
|
||||||
const message: RelayMessage = ["AUTH", { id: "123", kind: CLIENT_AUTH }]
|
const message: RelayMessage = ["AUTH", { id: "123", kind: CLIENT_AUTH }]
|
||||||
socket.emit(SocketEvent.Enqueue, message)
|
socket.emit(SocketEvent.Sending, message)
|
||||||
|
|
||||||
expect(authManager.state.status).toBe(AuthStatus.PendingResponse)
|
expect(authManager.state.status).toBe(AuthStatus.PendingResponse)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ describe('policy', () => {
|
|||||||
|
|
||||||
// Send a message
|
// Send a message
|
||||||
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
||||||
socket.emit(SocketEvent.Enqueue, event)
|
socket.emit(SocketEvent.Sending, event)
|
||||||
|
|
||||||
// Should open the socket
|
// Should open the socket
|
||||||
expect(openSpy).toHaveBeenCalled()
|
expect(openSpy).toHaveBeenCalled()
|
||||||
@@ -222,7 +222,7 @@ describe('policy', () => {
|
|||||||
|
|
||||||
// Send a message
|
// Send a message
|
||||||
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
||||||
socket.emit(SocketEvent.Enqueue, event)
|
socket.emit(SocketEvent.Sending, event)
|
||||||
|
|
||||||
// Should not try to open the socket
|
// Should not try to open the socket
|
||||||
expect(openSpy).not.toHaveBeenCalled()
|
expect(openSpy).not.toHaveBeenCalled()
|
||||||
@@ -240,7 +240,7 @@ describe('policy', () => {
|
|||||||
|
|
||||||
// Send a message
|
// Send a message
|
||||||
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
const event: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
||||||
socket.emit(SocketEvent.Enqueue, event)
|
socket.emit(SocketEvent.Sending, event)
|
||||||
|
|
||||||
// Should not try to open the socket due to recent error
|
// Should not try to open the socket due to recent error
|
||||||
expect(openSpy).not.toHaveBeenCalled()
|
expect(openSpy).not.toHaveBeenCalled()
|
||||||
@@ -249,7 +249,7 @@ describe('policy', () => {
|
|||||||
vi.advanceTimersByTime(31000)
|
vi.advanceTimersByTime(31000)
|
||||||
|
|
||||||
// Send another message
|
// Send another message
|
||||||
socket.emit(SocketEvent.Enqueue, event)
|
socket.emit(SocketEvent.Sending, event)
|
||||||
|
|
||||||
// Now it should try to open
|
// Now it should try to open
|
||||||
expect(openSpy).toHaveBeenCalled()
|
expect(openSpy).toHaveBeenCalled()
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ describe("Socket", () => {
|
|||||||
describe("send", () => {
|
describe("send", () => {
|
||||||
it("should queue messages and emit enqueue event", () => {
|
it("should queue messages and emit enqueue event", () => {
|
||||||
const enqueueSpy = vi.fn()
|
const enqueueSpy = vi.fn()
|
||||||
socket.on(SocketEvent.Enqueue, enqueueSpy)
|
socket.on(SocketEvent.Sending, enqueueSpy)
|
||||||
|
|
||||||
const message: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
const message: ClientMessage = ["EVENT", { id: "123", kind: 1 }]
|
||||||
socket.send(message)
|
socket.send(message)
|
||||||
|
|||||||
@@ -52,8 +52,6 @@ export class AuthState extends EventEmitter {
|
|||||||
if (isRelayOk(message)) {
|
if (isRelayOk(message)) {
|
||||||
const [_, id, ok, details] = message
|
const [_, id, ok, details] = message
|
||||||
|
|
||||||
console.log("ok", message)
|
|
||||||
|
|
||||||
if (id === this.request) {
|
if (id === this.request) {
|
||||||
this.details = details
|
this.details = details
|
||||||
|
|
||||||
@@ -68,23 +66,19 @@ export class AuthState extends EventEmitter {
|
|||||||
if (isRelayAuth(message)) {
|
if (isRelayAuth(message)) {
|
||||||
const [_, challenge] = message
|
const [_, challenge] = message
|
||||||
|
|
||||||
console.log("relay auth", message)
|
|
||||||
|
|
||||||
this.challenge = challenge
|
this.challenge = challenge
|
||||||
this.request = undefined
|
this.request = undefined
|
||||||
this.details = undefined
|
this.details = undefined
|
||||||
this.setStatus(AuthStatus.Requested)
|
this.setStatus(AuthStatus.Requested)
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
on(socket, SocketEvent.Enqueue, (message: RelayMessage) => {
|
on(socket, SocketEvent.Sending, (message: RelayMessage) => {
|
||||||
if (isClientAuth(message)) {
|
if (isClientAuth(message)) {
|
||||||
console.log("client auth", message)
|
|
||||||
this.setStatus(AuthStatus.PendingResponse)
|
this.setStatus(AuthStatus.PendingResponse)
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
on(socket, SocketEvent.Status, (status: SocketStatus) => {
|
on(socket, SocketEvent.Status, (status: SocketStatus) => {
|
||||||
if (status === SocketStatus.Closed) {
|
if (status === SocketStatus.Closed) {
|
||||||
console.log("closed")
|
|
||||||
this.challenge = undefined
|
this.challenge = undefined
|
||||||
this.request = undefined
|
this.request = undefined
|
||||||
this.details = undefined
|
this.details = undefined
|
||||||
@@ -113,8 +107,6 @@ export class AuthState extends EventEmitter {
|
|||||||
const template = makeAuthEvent(this.socket.url, this.challenge)
|
const template = makeAuthEvent(this.socket.url, this.challenge)
|
||||||
const event = await sign(template)
|
const event = await sign(template)
|
||||||
|
|
||||||
console.log(event)
|
|
||||||
|
|
||||||
if (event) {
|
if (event) {
|
||||||
this.request = event.id
|
this.request = event.id
|
||||||
this.socket.send(["AUTH", event])
|
this.socket.send(["AUTH", event])
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import {
|
|||||||
isRelayClosed,
|
isRelayClosed,
|
||||||
} from "./message.js"
|
} from "./message.js"
|
||||||
import {Socket, SocketStatus, SocketEvent} from "./socket.js"
|
import {Socket, SocketStatus, SocketEvent} from "./socket.js"
|
||||||
import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js"
|
import {AuthStatus, AuthStateEvent} from "./auth.js"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Defers sending messages when a challenge has been presented and not answered yet
|
* Defers sending messages when a challenge has been presented and not answered yet
|
||||||
@@ -21,12 +21,11 @@ import {AuthState, AuthStatus, AuthStateEvent} from "./auth.js"
|
|||||||
*/
|
*/
|
||||||
export const socketPolicyDeferOnAuth = (socket: Socket) => {
|
export const socketPolicyDeferOnAuth = (socket: Socket) => {
|
||||||
const buffer: ClientMessage[] = []
|
const buffer: ClientMessage[] = []
|
||||||
const authState = new AuthState(socket)
|
|
||||||
const okStatuses = [AuthStatus.None, AuthStatus.Ok]
|
const okStatuses = [AuthStatus.None, AuthStatus.Ok]
|
||||||
|
|
||||||
const unsubscribers = [
|
const unsubscribers = [
|
||||||
// Pause sending certain messages when we're not authenticated
|
// Pause sending certain messages when we're not authenticated
|
||||||
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => {
|
on(socket, SocketEvent.Sending, (message: ClientMessage) => {
|
||||||
// If we're closing a request, but it never got sent, remove both from the queue
|
// If we're closing a request, but it never got sent, remove both from the queue
|
||||||
// Otherwise, always send CLOSE
|
// Otherwise, always send CLOSE
|
||||||
if (isClientClose(message)) {
|
if (isClientClose(message)) {
|
||||||
@@ -47,13 +46,13 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
|
|||||||
if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return
|
if (isClientEvent(message) && message[1].kind === AUTH_JOIN) return
|
||||||
|
|
||||||
// If we're not ok, remove the message and save it for later
|
// If we're not ok, remove the message and save it for later
|
||||||
if (!okStatuses.includes(authState.status)) {
|
if (!okStatuses.includes(socket.auth.status)) {
|
||||||
buffer.push(message)
|
buffer.push(message)
|
||||||
socket._sendQueue.remove(message)
|
socket._sendQueue.remove(message)
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
// Send buffered messages when we get successful auth
|
// Send buffered messages when we get successful auth
|
||||||
on(authState, AuthStateEvent.Status, (status: AuthStatus) => {
|
on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
|
||||||
if (okStatuses.includes(status) && buffer.length > 0) {
|
if (okStatuses.includes(status) && buffer.length > 0) {
|
||||||
for (const message of buffer.splice(0)) {
|
for (const message of buffer.splice(0)) {
|
||||||
socket.send(message)
|
socket.send(message)
|
||||||
@@ -64,7 +63,6 @@ export const socketPolicyDeferOnAuth = (socket: Socket) => {
|
|||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
unsubscribers.forEach(call)
|
unsubscribers.forEach(call)
|
||||||
authState.cleanup()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +140,7 @@ export const socketPolicyConnectOnSend = (socket: Socket) => {
|
|||||||
lastError = now()
|
lastError = now()
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
on(socket, SocketEvent.Enqueue, (message: ClientMessage) => {
|
on(socket, SocketEvent.Sending, (message: ClientMessage) => {
|
||||||
// When a new message is sent, make sure the socket is open (unless there was a recent error)
|
// When a new message is sent, make sure the socket is open (unless there was a recent error)
|
||||||
if (socket.status === SocketStatus.Closed && lastError < ago(30)) {
|
if (socket.status === SocketStatus.Closed && lastError < ago(30)) {
|
||||||
socket.open()
|
socket.open()
|
||||||
@@ -242,20 +240,18 @@ export type SocketPolicyAuthOptions = {
|
|||||||
* @return a socket policy
|
* @return a socket policy
|
||||||
*/
|
*/
|
||||||
export const makeSocketPolicyAuth = (options: SocketPolicyAuthOptions) => (socket: Socket) => {
|
export const makeSocketPolicyAuth = (options: SocketPolicyAuthOptions) => (socket: Socket) => {
|
||||||
const authState = new AuthState(socket)
|
|
||||||
const shouldAuth = options.shouldAuth || always(true)
|
const shouldAuth = options.shouldAuth || always(true)
|
||||||
|
|
||||||
const unsubscribers = [
|
const unsubscribers = [
|
||||||
on(authState, AuthStateEvent.Status, (status: AuthStatus) => {
|
on(socket.auth, AuthStateEvent.Status, (status: AuthStatus) => {
|
||||||
if (status === AuthStatus.Requested && shouldAuth(socket)) {
|
if (status === AuthStatus.Requested && shouldAuth(socket)) {
|
||||||
authState.authenticate(options.sign)
|
socket.auth.authenticate(options.sign)
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
unsubscribers.forEach(call)
|
unsubscribers.forEach(call)
|
||||||
authState.cleanup()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+11
-22
@@ -1,7 +1,6 @@
|
|||||||
import {remove} from "@welshman/lib"
|
import {remove} from "@welshman/lib"
|
||||||
import {normalizeRelayUrl} from "@welshman/util"
|
import {normalizeRelayUrl} from "@welshman/util"
|
||||||
import {Socket} from "./socket.js"
|
import {Socket} from "./socket.js"
|
||||||
import {AuthState} from "./auth.js"
|
|
||||||
import {defaultSocketPolicies} from "./policy.js"
|
import {defaultSocketPolicies} from "./policy.js"
|
||||||
|
|
||||||
export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
|
export const makeSocket = (url: string, policies = defaultSocketPolicies) => {
|
||||||
@@ -22,13 +21,8 @@ export type PoolOptions = {
|
|||||||
|
|
||||||
export let poolSingleton: Pool
|
export let poolSingleton: Pool
|
||||||
|
|
||||||
export type PoolItem = {
|
|
||||||
socket: Socket
|
|
||||||
auth: AuthState
|
|
||||||
}
|
|
||||||
|
|
||||||
export class Pool {
|
export class Pool {
|
||||||
_data = new Map<string, PoolItem>()
|
_data = new Map<string, Socket>()
|
||||||
_subs: PoolSubscription[] = []
|
_subs: PoolSubscription[] = []
|
||||||
|
|
||||||
static getSingleton() {
|
static getSingleton() {
|
||||||
@@ -55,25 +49,21 @@ export class Pool {
|
|||||||
|
|
||||||
get(_url: string): Socket {
|
get(_url: string): Socket {
|
||||||
const url = normalizeRelayUrl(_url)
|
const url = normalizeRelayUrl(_url)
|
||||||
const item = this._data.get(url)
|
const socket = this._data.get(url)
|
||||||
|
|
||||||
if (item) {
|
if (socket) {
|
||||||
return item.socket
|
return socket
|
||||||
}
|
}
|
||||||
|
|
||||||
const socket = this.makeSocket(url)
|
const newSocket = this.makeSocket(url)
|
||||||
|
|
||||||
this._data.set(url, {socket, auth: new AuthState(socket)})
|
this._data.set(url, newSocket)
|
||||||
|
|
||||||
for (const cb of this._subs) {
|
for (const cb of this._subs) {
|
||||||
cb(socket)
|
cb(newSocket)
|
||||||
}
|
}
|
||||||
|
|
||||||
return socket
|
return newSocket
|
||||||
}
|
|
||||||
|
|
||||||
getAuth(url: string) {
|
|
||||||
return this._data.get(normalizeRelayUrl(url))?.auth
|
|
||||||
}
|
}
|
||||||
|
|
||||||
subscribe(cb: PoolSubscription) {
|
subscribe(cb: PoolSubscription) {
|
||||||
@@ -85,11 +75,10 @@ export class Pool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
remove(url: string) {
|
remove(url: string) {
|
||||||
const item = this._data.get(url)
|
const socket = this._data.get(url)
|
||||||
|
|
||||||
if (item) {
|
if (socket) {
|
||||||
item.socket.cleanup()
|
socket.cleanup()
|
||||||
item.auth.cleanup()
|
|
||||||
|
|
||||||
this._data.delete(url)
|
this._data.delete(url)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import WebSocket from "isomorphic-ws"
|
|||||||
import EventEmitter from "events"
|
import EventEmitter from "events"
|
||||||
import {TaskQueue} from "@welshman/lib"
|
import {TaskQueue} from "@welshman/lib"
|
||||||
import {RelayMessage, ClientMessage} from "./message.js"
|
import {RelayMessage, ClientMessage} from "./message.js"
|
||||||
|
import {AuthState} from "./auth.js"
|
||||||
|
|
||||||
export enum SocketStatus {
|
export enum SocketStatus {
|
||||||
Open = "socket:status:open",
|
Open = "socket:status:open",
|
||||||
@@ -16,19 +17,22 @@ export enum SocketEvent {
|
|||||||
Error = "socket:event:error",
|
Error = "socket:event:error",
|
||||||
Status = "socket:event:status",
|
Status = "socket:event:status",
|
||||||
Send = "socket:event:send",
|
Send = "socket:event:send",
|
||||||
Enqueue = "socket:event:enqueue",
|
Sending = "socket:event:sending",
|
||||||
Receive = "socket:event:receive",
|
Receive = "socket:event:receive",
|
||||||
|
Receiving = "socket:event:receiving",
|
||||||
}
|
}
|
||||||
|
|
||||||
export type SocketEvents = {
|
export type SocketEvents = {
|
||||||
[SocketEvent.Error]: (error: string, url: string) => void
|
[SocketEvent.Error]: (error: string, url: string) => void
|
||||||
[SocketEvent.Status]: (status: SocketStatus, url: string) => void
|
[SocketEvent.Status]: (status: SocketStatus, url: string) => void
|
||||||
[SocketEvent.Send]: (message: ClientMessage, url: string) => void
|
[SocketEvent.Send]: (message: ClientMessage, url: string) => void
|
||||||
[SocketEvent.Enqueue]: (message: ClientMessage, url: string) => void
|
[SocketEvent.Sending]: (message: ClientMessage, url: string) => void
|
||||||
[SocketEvent.Receive]: (message: RelayMessage, url: string) => void
|
[SocketEvent.Receive]: (message: RelayMessage, url: string) => void
|
||||||
|
[SocketEvent.Receiving]: (message: RelayMessage, url: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Socket extends EventEmitter {
|
export class Socket extends EventEmitter {
|
||||||
|
auth: AuthState
|
||||||
status = SocketStatus.Closed
|
status = SocketStatus.Closed
|
||||||
|
|
||||||
_ws?: WebSocket
|
_ws?: WebSocket
|
||||||
@@ -38,6 +42,8 @@ export class Socket extends EventEmitter {
|
|||||||
constructor(readonly url: string) {
|
constructor(readonly url: string) {
|
||||||
super()
|
super()
|
||||||
|
|
||||||
|
this.auth = new AuthState(this)
|
||||||
|
|
||||||
this._sendQueue = new TaskQueue<ClientMessage>({
|
this._sendQueue = new TaskQueue<ClientMessage>({
|
||||||
batchSize: 50,
|
batchSize: 50,
|
||||||
processItem: (message: ClientMessage) => {
|
processItem: (message: ClientMessage) => {
|
||||||
@@ -84,7 +90,6 @@ export class Socket extends EventEmitter {
|
|||||||
this._ws.onclose = () => {
|
this._ws.onclose = () => {
|
||||||
this._ws = undefined
|
this._ws = undefined
|
||||||
this._sendQueue.stop()
|
this._sendQueue.stop()
|
||||||
console.log("socket closed", this.url)
|
|
||||||
this.emit(SocketEvent.Status, SocketStatus.Closed, this.url)
|
this.emit(SocketEvent.Status, SocketStatus.Closed, this.url)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,6 +101,7 @@ export class Socket extends EventEmitter {
|
|||||||
|
|
||||||
if (Array.isArray(message)) {
|
if (Array.isArray(message)) {
|
||||||
this._recvQueue.push(message as RelayMessage)
|
this._recvQueue.push(message as RelayMessage)
|
||||||
|
this.emit(SocketEvent.Receiving, message, this.url)
|
||||||
} else {
|
} else {
|
||||||
this.emit(SocketEvent.Error, "Invalid message received", this.url)
|
this.emit(SocketEvent.Error, "Invalid message received", this.url)
|
||||||
}
|
}
|
||||||
@@ -121,6 +127,7 @@ export class Socket extends EventEmitter {
|
|||||||
|
|
||||||
cleanup = () => {
|
cleanup = () => {
|
||||||
this.close()
|
this.close()
|
||||||
|
this.auth.cleanup()
|
||||||
this._recvQueue.clear()
|
this._recvQueue.clear()
|
||||||
this._sendQueue.clear()
|
this._sendQueue.clear()
|
||||||
this.removeAllListeners()
|
this.removeAllListeners()
|
||||||
@@ -128,6 +135,6 @@ export class Socket extends EventEmitter {
|
|||||||
|
|
||||||
send = (message: ClientMessage) => {
|
send = (message: ClientMessage) => {
|
||||||
this._sendQueue.push(message)
|
this._sendQueue.push(message)
|
||||||
this.emit(SocketEvent.Enqueue, message, this.url)
|
this.emit(SocketEvent.Sending, message, this.url)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user