forked from coracle/nostrlib
ensure we fail subscriptions to closed relays.
This commit is contained in:
@@ -99,17 +99,38 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cause := context.Cause(ctx)
|
||||
code := ws.StatusNormalClosure
|
||||
reason := ""
|
||||
var cc closeCause
|
||||
if errors.As(cause, &cc) {
|
||||
code = cc.code
|
||||
reason = cc.reason
|
||||
} else if cause != nil {
|
||||
reason = cause.Error()
|
||||
|
||||
if wasClosed := r.closed.Swap(true); wasClosed {
|
||||
return
|
||||
}
|
||||
|
||||
r.closeMutex.Invalidate()
|
||||
|
||||
if r.conn != nil {
|
||||
cause := context.Cause(ctx)
|
||||
code := ws.StatusNormalClosure
|
||||
reason := ""
|
||||
var cc closeCause
|
||||
if errors.As(cause, &cc) {
|
||||
code = cc.code
|
||||
reason = cc.reason
|
||||
} else if cause != nil {
|
||||
reason = cause.Error()
|
||||
}
|
||||
|
||||
_ = r.conn.Close(code, reason)
|
||||
}
|
||||
if r.closeMutex != nil {
|
||||
r.closeMutex.Lock()
|
||||
if r.closedNotify != nil {
|
||||
close(r.closedNotify)
|
||||
}
|
||||
if r.writeQueue != nil {
|
||||
close(r.writeQueue)
|
||||
}
|
||||
r.conn = nil
|
||||
r.closeMutex.Unlock()
|
||||
}
|
||||
r.closeConnection(code, reason)
|
||||
}()
|
||||
|
||||
return r
|
||||
@@ -229,10 +250,7 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
|
||||
|
||||
r.conn = c
|
||||
r.writeQueue = make(chan writeRequest)
|
||||
if r.closed == nil {
|
||||
r.closed = &atomic.Bool{}
|
||||
}
|
||||
r.closed.Store(false)
|
||||
r.closed = &atomic.Bool{}
|
||||
r.closedNotify = make(chan struct{})
|
||||
|
||||
connCtx := r.connectionContext
|
||||
@@ -317,30 +335,6 @@ func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) erro
|
||||
}
|
||||
|
||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||
if r.closed == nil {
|
||||
r.closed = &atomic.Bool{}
|
||||
}
|
||||
wasClosed := r.closed.Swap(true)
|
||||
if wasClosed {
|
||||
return
|
||||
}
|
||||
if r.closeMutex != nil {
|
||||
r.closeMutex.Invalidate()
|
||||
}
|
||||
if r.conn != nil {
|
||||
_ = r.conn.Close(code, reason)
|
||||
}
|
||||
if r.closeMutex != nil {
|
||||
r.closeMutex.Lock()
|
||||
if r.closedNotify != nil {
|
||||
close(r.closedNotify)
|
||||
}
|
||||
if r.writeQueue != nil {
|
||||
close(r.writeQueue)
|
||||
}
|
||||
r.conn = nil
|
||||
r.closeMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) handleMessage(message string) {
|
||||
@@ -582,13 +576,12 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
||||
// Failure to do that will result in a huge number of halted goroutines being created.
|
||||
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
|
||||
if r.conn == nil {
|
||||
sub.cancel(ErrNotConnected)
|
||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
||||
if r.conn == nil || r.closed.Load() {
|
||||
return nil, ErrDisconnected
|
||||
}
|
||||
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
sub.cancel(ErrFireFailed)
|
||||
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
||||
|
||||
Reference in New Issue
Block a user