.Count() to handle CLOSED messages and support AUTH like .Subscribe().
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -655,6 +656,9 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
||||
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
|
||||
sub.mu.Lock()
|
||||
close(sub.Events)
|
||||
if sub.countResult != nil {
|
||||
close(sub.countResult)
|
||||
}
|
||||
sub.mu.Unlock()
|
||||
}()
|
||||
|
||||
@@ -691,29 +695,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] {
|
||||
}
|
||||
|
||||
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
|
||||
func (r *Relay) Count(
|
||||
ctx context.Context,
|
||||
filter Filter,
|
||||
opts SubscriptionOptions,
|
||||
) (uint32, []byte, error) {
|
||||
// If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler.
|
||||
func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) {
|
||||
v, err := r.countInternal(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if v.Count == nil {
|
||||
return 0, nil, errors.New("count subscription ended without result")
|
||||
}
|
||||
|
||||
return *v.Count, v.HyperLogLog, nil
|
||||
}
|
||||
|
||||
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
sub.countResult = make(chan CountEnvelope)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
return CountEnvelope{}, err
|
||||
if !r.IsConnected() {
|
||||
return CountEnvelope{}, ErrDisconnected
|
||||
}
|
||||
|
||||
defer sub.cancel(errors.New("countInternal() ended"))
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.CancelFunc
|
||||
@@ -721,13 +721,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
hasAuthed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case count := <-sub.countResult:
|
||||
return count, nil
|
||||
case <-ctx.Done():
|
||||
return CountEnvelope{}, ctx.Err()
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
sub.countResult = make(chan CountEnvelope, 1)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
sub.cancel(ErrFireFailed)
|
||||
return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
sub.cancel(nil)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case count, ok := <-sub.countResult:
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
if !ok || count.Count == nil {
|
||||
return CountEnvelope{}, errors.New("count subscription ended without result")
|
||||
}
|
||||
return count, nil
|
||||
case reason := <-sub.ClosedReason:
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed {
|
||||
authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error {
|
||||
return r.authHandler(authCtx, r, evt)
|
||||
})
|
||||
if authErr == nil {
|
||||
hasAuthed = true
|
||||
goto resubscribe
|
||||
}
|
||||
return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr)
|
||||
}
|
||||
return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason)
|
||||
case <-sub.Context.Done():
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
return CountEnvelope{}, context.Cause(sub.Context)
|
||||
case <-ctx.Done():
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
return CountEnvelope{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
resubscribe:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user