diff --git a/pool.go b/pool.go index 356d071..1ea6928 100644 --- a/pool.go +++ b/pool.go @@ -33,8 +33,7 @@ type Pool struct { relayOptions RelayOptions // custom things not often used - penaltyBoxMu sync.Mutex - penaltyBox map[string][2]float64 + penaltyBox *xsync.MapOf[string, [2]float64] } // DirectedFilter combines a Filter with a specific relay URL. @@ -99,23 +98,22 @@ type PoolOptions struct { } func (pool *Pool) startPenaltyBox() { - pool.penaltyBox = make(map[string][2]float64) + pool.penaltyBox = xsync.NewMapOf[string, [2]float64]() go func() { sleep := 30.0 for { time.Sleep(time.Duration(sleep) * time.Second) - pool.penaltyBoxMu.Lock() nextSleep := 300.0 - for url, v := range pool.penaltyBox { + for url, v := range pool.penaltyBox.Range { remainingSeconds := v[1] remainingSeconds -= sleep if remainingSeconds <= 0 { - pool.penaltyBox[url] = [2]float64{v[0], 0} + pool.penaltyBox.Store(url, [2]float64{v[0], 0}) continue } else { - pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds} + pool.penaltyBox.Store(url, [2]float64{v[0], remainingSeconds}) } if remainingSeconds < nextSleep { @@ -124,7 +122,6 @@ func (pool *Pool) startPenaltyBox() { } sleep = nextSleep - pool.penaltyBoxMu.Unlock() } }() } @@ -138,9 +135,7 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) { relays, ok := pool.Relays.Load(nm) if ok && relays == nil { if pool.penaltyBox != nil { - pool.penaltyBoxMu.Lock() - defer pool.penaltyBoxMu.Unlock() - v, _ := pool.penaltyBox[nm] + v, _ := pool.penaltyBox.Load(nm) if v[1] > 0 { return nil, fmt.Errorf("in penalty box, %fs remaining", v[1]) } @@ -173,10 +168,9 @@ func (pool *Pool) ensureNewRelay(nm string) (*Relay, error) { if err := relay.Connect(pool.Context); err != nil { if pool.penaltyBox != nil { // putting relay in penalty box - pool.penaltyBoxMu.Lock() - defer pool.penaltyBoxMu.Unlock() - v, _ := pool.penaltyBox[nm] - pool.penaltyBox[nm] = [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)} + pool.penaltyBox.Compute(nm, func(v [2]float64, loaded bool) (newV [2]float64, delete bool) { + return [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)}, false + }) pool.Relays.Store(nm, nil) // this is important for penalty box detection on EnsureRelay } return nil, fmt.Errorf("failed to connect: %w", err) @@ -421,7 +415,9 @@ func (pool *Pool) FetchManyReplaceable( sub, err := relay.Subscribe(ctx, filter, opts) if errors.Is(err, ErrTooManySubscriptions) { + unlock := namedLock(relay.URL) newRelay, newErr := pool.ensureNewRelay(relay.URL) + unlock() if newErr != nil { return } @@ -571,7 +567,9 @@ func (pool *Pool) subMany( sub, err = relay.Subscribe(ctx, filter, opts) if errors.Is(err, ErrTooManySubscriptions) { + unlock := namedLock(relay.URL) newRelay, newErr := pool.ensureNewRelay(relay.URL) + unlock() if newErr == nil { newRelay.realSubscriptionsLimit = relay.realSubscriptionsLimit relay = newRelay @@ -725,7 +723,9 @@ func (pool *Pool) subManyEose( sub, err := relay.Subscribe(ctx, filter, opts) if errors.Is(err, ErrTooManySubscriptions) { + unlock := namedLock(relay.URL) newRelay, newErr := pool.ensureNewRelay(relay.URL) + unlock() if newErr != nil { return } @@ -821,7 +821,9 @@ func (pool *Pool) CountMany( ce, err := relay.countInternal(ctx, filter, opts) if errors.Is(err, ErrTooManySubscriptions) { + unlock := namedLock(relay.URL) newRelay, newErr := pool.ensureNewRelay(relay.URL) + unlock() if newErr != nil { return }