diff --git a/pool.go b/pool.go index 9b2a8b5..e9fa12b 100644 --- a/pool.go +++ b/pool.go @@ -455,15 +455,25 @@ func (pool *Pool) subMany( } } - pending := xsync.NewCounter() - pending.Add(int64(len(urls))) + pendingWg := sync.WaitGroup{} + pendingWg.Add(len(urls)) + + go func() { + pendingWg.Wait() + close(events) + cancel(fmt.Errorf("aborted: %w", context.Cause(ctx))) + if closedChan != nil { + close(closedChan) + } + }() + for i, url := range urls { url = NormalizeURL(url) urls[i] = url if idx := slices.Index(urls, url); idx != i { // skip duplicate relays in the list eoseWg.Done() - pending.Dec() + pendingWg.Done() continue } @@ -471,14 +481,10 @@ func (pool *Pool) subMany( go func(nm string) { defer func() { - pending.Dec() - if pending.Value() == 0 { - close(events) - cancel(fmt.Errorf("aborted: %w", context.Cause(ctx))) - } if eosed.CompareAndSwap(false, true) { eoseWg.Done() } + pendingWg.Done() }() hasAuthed := false @@ -565,10 +571,13 @@ func (pool *Pool) subMany( if err == nil { hasAuthed = true // so we don't keep doing AUTH again and again if closedChan != nil { - closedChan <- RelayClosed{ + select { + case closedChan <- RelayClosed{ Reason: reason, Relay: relay, HandledAuth: true, + }: + case <-ctx.Done(): } } goto subscribe @@ -576,9 +585,12 @@ func (pool *Pool) subMany( } debugLogf("CLOSED from %s: '%s'\n", nm, reason) if closedChan != nil { - closedChan <- RelayClosed{ + select { + case closedChan <- RelayClosed{ Reason: reason, Relay: relay, + }: + case <-ctx.Done(): } } @@ -619,6 +631,9 @@ func (pool *Pool) subManyEose( wg.Wait() cancel(errors.New("all subscriptions ended")) close(events) + if closedChan != nil { + close(closedChan) + } }() for _, url := range urls { @@ -663,10 +678,13 @@ func (pool *Pool) subManyEose( if err == nil { hasAuthed = true // so we don't keep doing AUTH again and again if closedChan != nil { - closedChan <- RelayClosed{ + select { + case closedChan <- RelayClosed{ Relay: relay, Reason: reason, HandledAuth: true, + }: + case <-ctx.Done(): } } goto subscribe @@ -674,9 +692,12 @@ func (pool *Pool) subManyEose( } debugLogf("[pool] CLOSED from %s: '%s'\n", nm, reason) if closedChan != nil { - closedChan <- RelayClosed{ + select { + case closedChan <- RelayClosed{ Relay: relay, Reason: reason, + }: + case <-ctx.Done(): } } return @@ -781,6 +802,7 @@ func (pool *Pool) batchedQueryMany( wg := sync.WaitGroup{} wg.Add(len(dfs)) seenAlready := xsync.NewMapOf[ID, struct{}]() + forwardWg := sync.WaitGroup{} opts.CheckDuplicate = func(id ID, relay string) bool { _, exists := seenAlready.LoadOrStore(id, struct{}{}) @@ -792,10 +814,28 @@ func (pool *Pool) batchedQueryMany( for _, df := range dfs { go func(df DirectedFilter) { + var innerClosed chan RelayClosed + if closedChan != nil { + innerClosed = make(chan RelayClosed) + forwardWg.Add(1) + go func() { + defer forwardWg.Done() + for rc := range innerClosed { + select { + case closedChan <- rc: + case <-ctx.Done(): + for range innerClosed { + } + return + } + } + }() + } + for ie := range pool.subManyEose(ctx, []string{df.Relay}, df.Filter, - closedChan, + innerClosed, opts, ) { select { @@ -812,6 +852,10 @@ func (pool *Pool) batchedQueryMany( go func() { wg.Wait() close(res) + if closedChan != nil { + forwardWg.Wait() + close(closedChan) + } }() return res @@ -847,6 +891,7 @@ func (pool *Pool) batchedSubscribeMany( wg := sync.WaitGroup{} wg.Add(len(dfs)) seenAlready := xsync.NewMapOf[ID, struct{}]() + forwardWg := sync.WaitGroup{} opts.CheckDuplicate = func(id ID, relay string) bool { _, exists := seenAlready.LoadOrStore(id, struct{}{}) @@ -858,11 +903,29 @@ func (pool *Pool) batchedSubscribeMany( for _, df := range dfs { go func(df DirectedFilter) { + var innerClosed chan RelayClosed + if closedChan != nil { + innerClosed = make(chan RelayClosed) + forwardWg.Add(1) + go func() { + defer forwardWg.Done() + for rc := range innerClosed { + select { + case closedChan <- rc: + case <-ctx.Done(): + for range innerClosed { + } + return + } + } + }() + } + for ie := range pool.subMany(ctx, []string{df.Relay}, df.Filter, nil, - closedChan, + innerClosed, opts, ) { select { @@ -879,6 +942,10 @@ func (pool *Pool) batchedSubscribeMany( go func() { wg.Wait() close(res) + if closedChan != nil { + forwardWg.Wait() + close(closedChan) + } }() return res