.Count() to handle CLOSED messages and support AUTH like .Subscribe().

This commit is contained in:
fiatjaf
2026-03-25 09:56:04 -03:00
parent d43fbbf02d
commit ec6f3f8a41
+58 -17
View File
@@ -13,6 +13,7 @@ import (
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -655,6 +656,9 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
sub.mu.Lock()
close(sub.Events)
if sub.countResult != nil {
close(sub.countResult)
}
sub.mu.Unlock()
}()
@@ -691,29 +695,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] {
}
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
func (r *Relay) Count(
ctx context.Context,
filter Filter,
opts SubscriptionOptions,
) (uint32, []byte, error) {
// If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler.
func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) {
v, err := r.countInternal(ctx, filter, opts)
if err != nil {
return 0, nil, err
}
if v.Count == nil {
return 0, nil, errors.New("count subscription ended without result")
}
return *v.Count, v.HyperLogLog, nil
}
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
sub := r.PrepareSubscription(ctx, filter, opts)
sub.countResult = make(chan CountEnvelope)
if err := sub.Fire(); err != nil {
return CountEnvelope{}, err
if !r.IsConnected() {
return CountEnvelope{}, ErrDisconnected
}
defer sub.cancel(errors.New("countInternal() ended"))
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.CancelFunc
@@ -721,13 +721,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript
defer cancel()
}
hasAuthed := false
for {
select {
case count := <-sub.countResult:
return count, nil
case <-ctx.Done():
return CountEnvelope{}, ctx.Err()
sub := r.PrepareSubscription(ctx, filter, opts)
sub.countResult = make(chan CountEnvelope, 1)
if err := sub.Fire(); err != nil {
sub.cancel(ErrFireFailed)
return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err)
}
go func() {
<-ctx.Done()
sub.cancel(nil)
}()
for {
select {
case count, ok := <-sub.countResult:
sub.cancel(errors.New("countInternal() ended"))
if !ok || count.Count == nil {
return CountEnvelope{}, errors.New("count subscription ended without result")
}
return count, nil
case reason := <-sub.ClosedReason:
sub.cancel(errors.New("countInternal() ended"))
if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed {
authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error {
return r.authHandler(authCtx, r, evt)
})
if authErr == nil {
hasAuthed = true
goto resubscribe
}
return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr)
}
return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason)
case <-sub.Context.Done():
sub.cancel(errors.New("countInternal() ended"))
return CountEnvelope{}, context.Cause(sub.Context)
case <-ctx.Done():
sub.cancel(errors.New("countInternal() ended"))
return CountEnvelope{}, ctx.Err()
}
}
resubscribe:
continue
}
}