ensure we fail subscriptions to closed relays.
This commit is contained in:
@@ -99,6 +99,14 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if wasClosed := r.closed.Swap(true); wasClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.closeMutex.Invalidate()
|
||||||
|
|
||||||
|
if r.conn != nil {
|
||||||
cause := context.Cause(ctx)
|
cause := context.Cause(ctx)
|
||||||
code := ws.StatusNormalClosure
|
code := ws.StatusNormalClosure
|
||||||
reason := ""
|
reason := ""
|
||||||
@@ -109,7 +117,20 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
|||||||
} else if cause != nil {
|
} else if cause != nil {
|
||||||
reason = cause.Error()
|
reason = cause.Error()
|
||||||
}
|
}
|
||||||
r.closeConnection(code, reason)
|
|
||||||
|
_ = r.conn.Close(code, reason)
|
||||||
|
}
|
||||||
|
if r.closeMutex != nil {
|
||||||
|
r.closeMutex.Lock()
|
||||||
|
if r.closedNotify != nil {
|
||||||
|
close(r.closedNotify)
|
||||||
|
}
|
||||||
|
if r.writeQueue != nil {
|
||||||
|
close(r.writeQueue)
|
||||||
|
}
|
||||||
|
r.conn = nil
|
||||||
|
r.closeMutex.Unlock()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
@@ -229,10 +250,7 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
|
|||||||
|
|
||||||
r.conn = c
|
r.conn = c
|
||||||
r.writeQueue = make(chan writeRequest)
|
r.writeQueue = make(chan writeRequest)
|
||||||
if r.closed == nil {
|
|
||||||
r.closed = &atomic.Bool{}
|
r.closed = &atomic.Bool{}
|
||||||
}
|
|
||||||
r.closed.Store(false)
|
|
||||||
r.closedNotify = make(chan struct{})
|
r.closedNotify = make(chan struct{})
|
||||||
|
|
||||||
connCtx := r.connectionContext
|
connCtx := r.connectionContext
|
||||||
@@ -317,30 +335,6 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||||
if r.closed == nil {
|
|
||||||
r.closed = &atomic.Bool{}
|
|
||||||
}
|
|
||||||
wasClosed := r.closed.Swap(true)
|
|
||||||
if wasClosed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.closeMutex != nil {
|
|
||||||
r.closeMutex.Invalidate()
|
|
||||||
}
|
|
||||||
if r.conn != nil {
|
|
||||||
_ = r.conn.Close(code, reason)
|
|
||||||
}
|
|
||||||
if r.closeMutex != nil {
|
|
||||||
r.closeMutex.Lock()
|
|
||||||
if r.closedNotify != nil {
|
|
||||||
close(r.closedNotify)
|
|
||||||
}
|
|
||||||
if r.writeQueue != nil {
|
|
||||||
close(r.writeQueue)
|
|
||||||
}
|
|
||||||
r.conn = nil
|
|
||||||
r.closeMutex.Unlock()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) handleMessage(message string) {
|
func (r *Relay) handleMessage(message string) {
|
||||||
@@ -582,13 +576,12 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
|||||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
||||||
// Failure to do that will result in a huge number of halted goroutines being created.
|
// Failure to do that will result in a huge number of halted goroutines being created.
|
||||||
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
||||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
if r.conn == nil || r.closed.Load() {
|
||||||
|
return nil, ErrDisconnected
|
||||||
if r.conn == nil {
|
|
||||||
sub.cancel(ErrNotConnected)
|
|
||||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||||
|
|
||||||
if err := sub.Fire(); err != nil {
|
if err := sub.Fire(); err != nil {
|
||||||
sub.cancel(ErrFireFailed)
|
sub.cancel(ErrFireFailed)
|
||||||
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user