diff --git a/relay.go b/relay.go index c11b27a..427cc34 100644 --- a/relay.go +++ b/relay.go @@ -13,6 +13,7 @@ import ( "net/http" "net/textproto" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -655,6 +656,9 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub // do this so we don't have the possibility of closing the Events channel and then trying to send to it sub.mu.Lock() close(sub.Events) + if sub.countResult != nil { + close(sub.countResult) + } sub.mu.Unlock() }() @@ -691,29 +695,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] { } // Count sends a "COUNT" command to the relay and returns the count of events matching the filters. -func (r *Relay) Count( - ctx context.Context, - filter Filter, - opts SubscriptionOptions, -) (uint32, []byte, error) { +// If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler. +func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) { v, err := r.countInternal(ctx, filter, opts) if err != nil { return 0, nil, err } + if v.Count == nil { + return 0, nil, errors.New("count subscription ended without result") + } + return *v.Count, v.HyperLogLog, nil } func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) { - sub := r.PrepareSubscription(ctx, filter, opts) - sub.countResult = make(chan CountEnvelope) - - if err := sub.Fire(); err != nil { - return CountEnvelope{}, err + if !r.IsConnected() { + return CountEnvelope{}, ErrDisconnected } - defer sub.cancel(errors.New("countInternal() ended")) - if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds var cancel context.CancelFunc @@ -721,13 +721,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript defer cancel() } + hasAuthed := false + for { - select { - case count := <-sub.countResult: - return count, nil - case <-ctx.Done(): - return CountEnvelope{}, ctx.Err() + sub := r.PrepareSubscription(ctx, filter, opts) + sub.countResult = make(chan CountEnvelope, 1) + + if err := sub.Fire(); err != nil { + sub.cancel(ErrFireFailed) + return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err) } + + go func() { + <-ctx.Done() + sub.cancel(nil) + }() + + for { + select { + case count, ok := <-sub.countResult: + sub.cancel(errors.New("countInternal() ended")) + if !ok || count.Count == nil { + return CountEnvelope{}, errors.New("count subscription ended without result") + } + return count, nil + case reason := <-sub.ClosedReason: + sub.cancel(errors.New("countInternal() ended")) + if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed { + authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error { + return r.authHandler(authCtx, r, evt) + }) + if authErr == nil { + hasAuthed = true + goto resubscribe + } + return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr) + } + return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason) + case <-sub.Context.Done(): + sub.cancel(errors.New("countInternal() ended")) + return CountEnvelope{}, context.Cause(sub.Context) + case <-ctx.Done(): + sub.cancel(errors.New("countInternal() ended")) + return CountEnvelope{}, ctx.Err() + } + } + + resubscribe: + continue } }