diff --git a/relay.go b/relay.go index 0c8776a..df7ffe3 100644 --- a/relay.go +++ b/relay.go @@ -99,17 +99,38 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay { go func() { <-ctx.Done() - cause := context.Cause(ctx) - code := ws.StatusNormalClosure - reason := "" - var cc closeCause - if errors.As(cause, &cc) { - code = cc.code - reason = cc.reason - } else if cause != nil { - reason = cause.Error() + + if wasClosed := r.closed.Swap(true); wasClosed { + return + } + + r.closeMutex.Invalidate() + + if r.conn != nil { + cause := context.Cause(ctx) + code := ws.StatusNormalClosure + reason := "" + var cc closeCause + if errors.As(cause, &cc) { + code = cc.code + reason = cc.reason + } else if cause != nil { + reason = cause.Error() + } + + _ = r.conn.Close(code, reason) + } + if r.closeMutex != nil { + r.closeMutex.Lock() + if r.closedNotify != nil { + close(r.closedNotify) + } + if r.writeQueue != nil { + close(r.writeQueue) + } + r.conn = nil + r.closeMutex.Unlock() } - r.closeConnection(code, reason) }() return r @@ -229,10 +250,7 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro r.conn = c r.writeQueue = make(chan writeRequest) - if r.closed == nil { - r.closed = &atomic.Bool{} - } - r.closed.Store(false) + r.closed = &atomic.Bool{} r.closedNotify = make(chan struct{}) connCtx := r.connectionContext @@ -317,30 +335,6 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro } func (r *Relay) closeConnection(code ws.StatusCode, reason string) { - if r.closed == nil { - r.closed = &atomic.Bool{} - } - wasClosed := r.closed.Swap(true) - if wasClosed { - return - } - if r.closeMutex != nil { - r.closeMutex.Invalidate() - } - if r.conn != nil { - _ = r.conn.Close(code, reason) - } - if r.closeMutex != nil { - r.closeMutex.Lock() - if r.closedNotify != nil { - close(r.closedNotify) - } - if r.writeQueue != nil { - close(r.writeQueue) - } - r.conn = nil - r.closeMutex.Unlock() - } } func (r *Relay) handleMessage(message string) { @@ -582,13 +576,12 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error { // Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point. // Failure to do that will result in a huge number of halted goroutines being created. func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) { - sub := r.PrepareSubscription(ctx, filter, opts) - - if r.conn == nil { - sub.cancel(ErrNotConnected) - return nil, fmt.Errorf("not connected to %s", r.URL) + if r.conn == nil || r.closed.Load() { + return nil, ErrDisconnected } + sub := r.PrepareSubscription(ctx, filter, opts) + if err := sub.Fire(); err != nil { sub.cancel(ErrFireFailed) return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)