ensure we fail subscriptions to closed relays.

This commit is contained in:
fiatjaf
2026-03-01 09:26:00 -03:00
parent 4b5c51ffc0
commit 44c429d6b1
+26 -33
View File
@@ -99,6 +99,14 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
go func() { go func() {
<-ctx.Done() <-ctx.Done()
if wasClosed := r.closed.Swap(true); wasClosed {
return
}
r.closeMutex.Invalidate()
if r.conn != nil {
cause := context.Cause(ctx) cause := context.Cause(ctx)
code := ws.StatusNormalClosure code := ws.StatusNormalClosure
reason := "" reason := ""
@@ -109,7 +117,20 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
} else if cause != nil { } else if cause != nil {
reason = cause.Error() reason = cause.Error()
} }
r.closeConnection(code, reason)
_ = r.conn.Close(code, reason)
}
if r.closeMutex != nil {
r.closeMutex.Lock()
if r.closedNotify != nil {
close(r.closedNotify)
}
if r.writeQueue != nil {
close(r.writeQueue)
}
r.conn = nil
r.closeMutex.Unlock()
}
}() }()
return r return r
@@ -229,10 +250,7 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
r.conn = c r.conn = c
r.writeQueue = make(chan writeRequest) r.writeQueue = make(chan writeRequest)
if r.closed == nil {
r.closed = &atomic.Bool{} r.closed = &atomic.Bool{}
}
r.closed.Store(false)
r.closedNotify = make(chan struct{}) r.closedNotify = make(chan struct{})
connCtx := r.connectionContext connCtx := r.connectionContext
@@ -317,30 +335,6 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
} }
func (r *Relay) closeConnection(code ws.StatusCode, reason string) { func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
if r.closed == nil {
r.closed = &atomic.Bool{}
}
wasClosed := r.closed.Swap(true)
if wasClosed {
return
}
if r.closeMutex != nil {
r.closeMutex.Invalidate()
}
if r.conn != nil {
_ = r.conn.Close(code, reason)
}
if r.closeMutex != nil {
r.closeMutex.Lock()
if r.closedNotify != nil {
close(r.closedNotify)
}
if r.writeQueue != nil {
close(r.writeQueue)
}
r.conn = nil
r.closeMutex.Unlock()
}
} }
func (r *Relay) handleMessage(message string) { func (r *Relay) handleMessage(message string) {
@@ -582,13 +576,12 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point. // Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
// Failure to do that will result in a huge number of halted goroutines being created. // Failure to do that will result in a huge number of halted goroutines being created.
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) { func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
sub := r.PrepareSubscription(ctx, filter, opts) if r.conn == nil || r.closed.Load() {
return nil, ErrDisconnected
if r.conn == nil {
sub.cancel(ErrNotConnected)
return nil, fmt.Errorf("not connected to %s", r.URL)
} }
sub := r.PrepareSubscription(ctx, filter, opts)
if err := sub.Fire(); err != nil { if err := sub.Fire(); err != nil {
sub.cancel(ErrFireFailed) sub.cancel(ErrFireFailed)
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err) return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)