.Count() to handle CLOSED messages and support AUTH like .Subscribe().

This commit is contained in:
fiatjaf
2026-03-25 09:56:04 -03:00
parent d43fbbf02d
commit ec6f3f8a41
+58 -17
View File
@@ -13,6 +13,7 @@ import (
"net/http" "net/http"
"net/textproto" "net/textproto"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -655,6 +656,9 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
// do this so we don't have the possibility of closing the Events channel and then trying to send to it // do this so we don't have the possibility of closing the Events channel and then trying to send to it
sub.mu.Lock() sub.mu.Lock()
close(sub.Events) close(sub.Events)
if sub.countResult != nil {
close(sub.countResult)
}
sub.mu.Unlock() sub.mu.Unlock()
}() }()
@@ -691,29 +695,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] {
} }
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters. // Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
func (r *Relay) Count( // If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler.
ctx context.Context, func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) {
filter Filter,
opts SubscriptionOptions,
) (uint32, []byte, error) {
v, err := r.countInternal(ctx, filter, opts) v, err := r.countInternal(ctx, filter, opts)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
if v.Count == nil {
return 0, nil, errors.New("count subscription ended without result")
}
return *v.Count, v.HyperLogLog, nil return *v.Count, v.HyperLogLog, nil
} }
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) { func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
sub := r.PrepareSubscription(ctx, filter, opts) if !r.IsConnected() {
sub.countResult = make(chan CountEnvelope) return CountEnvelope{}, ErrDisconnected
if err := sub.Fire(); err != nil {
return CountEnvelope{}, err
} }
defer sub.cancel(errors.New("countInternal() ended"))
if _, ok := ctx.Deadline(); !ok { if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds // if no timeout is set, force it to 7 seconds
var cancel context.CancelFunc var cancel context.CancelFunc
@@ -721,13 +721,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript
defer cancel() defer cancel()
} }
hasAuthed := false
for { for {
select { sub := r.PrepareSubscription(ctx, filter, opts)
case count := <-sub.countResult: sub.countResult = make(chan CountEnvelope, 1)
return count, nil
case <-ctx.Done(): if err := sub.Fire(); err != nil {
return CountEnvelope{}, ctx.Err() sub.cancel(ErrFireFailed)
return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err)
} }
go func() {
<-ctx.Done()
sub.cancel(nil)
}()
for {
select {
case count, ok := <-sub.countResult:
sub.cancel(errors.New("countInternal() ended"))
if !ok || count.Count == nil {
return CountEnvelope{}, errors.New("count subscription ended without result")
}
return count, nil
case reason := <-sub.ClosedReason:
sub.cancel(errors.New("countInternal() ended"))
if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed {
authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error {
return r.authHandler(authCtx, r, evt)
})
if authErr == nil {
hasAuthed = true
goto resubscribe
}
return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr)
}
return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason)
case <-sub.Context.Done():
sub.cancel(errors.New("countInternal() ended"))
return CountEnvelope{}, context.Cause(sub.Context)
case <-ctx.Done():
sub.cancel(errors.New("countInternal() ended"))
return CountEnvelope{}, ctx.Err()
}
}
resubscribe:
continue
} }
} }