improve/refactor websocket connections hoping this will fix the undetected disconnections we're seeing.
this commit also remove all the sonic envelope parsing and reintroduces filters in REQ as a slice instead of as a singleton. why? well, the sonic stuff wasn't really that fast, it was a little bit but only got fast enough once I introduced unsafe conversions between []byte and string and did weird unsafe reuse of []byte in order to save the values of tags, which would definitely cause issues in the future if the caller wasn't aware of it (and even if they were, like myself). and the filters stuff is because we abandoned the idea of changing NIP-01 to only accept one filter per REQ.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
@@ -10,7 +9,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -38,7 +36,6 @@ type Relay struct {
|
||||
noticeHandler func(string) // NIP-01 NOTICEs
|
||||
customHandler func(string) // nonstandard unparseable messages
|
||||
okCallbacks *xsync.MapOf[ID, func(bool, string)]
|
||||
writeQueue chan writeRequest
|
||||
subscriptionChannelCloseQueue chan *Subscription
|
||||
|
||||
// custom things that aren't often used
|
||||
@@ -46,11 +43,6 @@ type Relay struct {
|
||||
AssumeValid bool // this will skip verifying signatures for events received from this relay
|
||||
}
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
// NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection.
|
||||
func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
@@ -60,7 +52,6 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
||||
connectionContextCancel: cancel,
|
||||
Subscriptions: xsync.NewMapOf[int64, *Subscription](),
|
||||
okCallbacks: xsync.NewMapOf[ID, func(bool, string)](),
|
||||
writeQueue: make(chan writeRequest),
|
||||
subscriptionChannelCloseQueue: make(chan *Subscription),
|
||||
requestHeader: opts.RequestHeader,
|
||||
}
|
||||
@@ -103,7 +94,7 @@ func (r *Relay) String() string {
|
||||
func (r *Relay) Context() context.Context { return r.connectionContext }
|
||||
|
||||
// IsConnected returns true if the connection to this relay seems to be active.
|
||||
func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil }
|
||||
func (r *Relay) IsConnected() bool { return !r.Connection.closed.Load() }
|
||||
|
||||
// Connect tries to establish a websocket connection to r.URL.
|
||||
// If the context expires before the connection is complete, an error is returned.
|
||||
@@ -128,164 +119,123 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
defer cancel()
|
||||
ctx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
}
|
||||
|
||||
conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
|
||||
conn, err := NewConnection(ctx, r.URL, r.handleMessage, r.requestHeader, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
||||
}
|
||||
r.Connection = conn
|
||||
|
||||
// ping every 29 seconds
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
|
||||
// queue all write operations here so we don't do mutex spaghetti
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
ticker.Stop()
|
||||
r.Connection = nil
|
||||
|
||||
for _, sub := range r.Subscriptions.Range {
|
||||
sub.unsub(fmt.Errorf("relay connection closed: %w / %w", context.Cause(r.connectionContext), r.ConnectionError))
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := r.Connection.Ping(r.connectionContext)
|
||||
if err != nil && !strings.Contains(err.Error(), "failed to wait for pong") {
|
||||
InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err)
|
||||
r.Close() // this should trigger a context cancelation
|
||||
return
|
||||
}
|
||||
case writeRequest := <-r.writeQueue:
|
||||
// all write requests will go through this to prevent races
|
||||
debugLogf("{%s} sending %v\n", r.URL, string(writeRequest.msg))
|
||||
if err := r.Connection.WriteMessage(r.connectionContext, writeRequest.msg); err != nil {
|
||||
writeRequest.answer <- err
|
||||
}
|
||||
close(writeRequest.answer)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// general message reader loop
|
||||
go func() {
|
||||
buf := new(bytes.Buffer)
|
||||
mp := NewMessageParser()
|
||||
|
||||
for {
|
||||
buf.Reset()
|
||||
|
||||
if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
|
||||
r.ConnectionError = err
|
||||
r.close(err)
|
||||
break
|
||||
}
|
||||
|
||||
message := string(buf.Bytes())
|
||||
debugLogf("{%s} received %v\n", r.URL, message)
|
||||
|
||||
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
||||
// as we skip handling duplicate events
|
||||
subid := extractSubID(message)
|
||||
sub, ok := r.Subscriptions.Load(subIdToSerial(subid))
|
||||
if ok {
|
||||
if sub.checkDuplicate != nil {
|
||||
if sub.checkDuplicate(extractEventID(message[10+len(subid):]), r.URL) {
|
||||
continue
|
||||
}
|
||||
} else if sub.checkDuplicateReplaceable != nil {
|
||||
if sub.checkDuplicateReplaceable(
|
||||
ReplaceableKey{extractEventPubKey(message), extractDTag(message)},
|
||||
extractTimestamp(message),
|
||||
) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envelope, err := mp.ParseMessage(message)
|
||||
if envelope == nil {
|
||||
if r.customHandler != nil && err == UnknownLabel {
|
||||
r.customHandler(message)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch env := envelope.(type) {
|
||||
case *NoticeEnvelope:
|
||||
// see WithNoticeHandler
|
||||
if r.noticeHandler != nil {
|
||||
r.noticeHandler(string(*env))
|
||||
} else {
|
||||
log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env))
|
||||
}
|
||||
case *AuthEnvelope:
|
||||
if env.Challenge == nil {
|
||||
continue
|
||||
}
|
||||
r.challenge = *env.Challenge
|
||||
case *EventEnvelope:
|
||||
// we already have the subscription from the pre-check above, so we can just reuse it
|
||||
if sub == nil {
|
||||
// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
|
||||
continue
|
||||
} else {
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !sub.match(env.Event) {
|
||||
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, sub.Filter, env.Event)
|
||||
continue
|
||||
}
|
||||
|
||||
// check signature, ignore invalid, except from trusted (AssumeValid) relays
|
||||
if !r.AssumeValid {
|
||||
if !env.Event.VerifySignature() {
|
||||
InfoLogger.Printf("{%s} bad signature on %s\n", r.URL, env.Event.ID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch this to the internal .events channel of the subscription
|
||||
sub.dispatchEvent(env.Event)
|
||||
}
|
||||
case *EOSEEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(string(*env))); ok {
|
||||
subscription.dispatchEose()
|
||||
}
|
||||
case *ClosedEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok {
|
||||
subscription.handleClosed(env.Reason)
|
||||
}
|
||||
case *CountEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
||||
subscription.countResult <- *env
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
okCallback(env.OK, env.Reason)
|
||||
} else {
|
||||
InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleMessage(message string) {
|
||||
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
||||
// as we skip handling duplicate events
|
||||
subid := extractSubID(message)
|
||||
sub, ok := r.Subscriptions.Load(subIdToSerial(subid))
|
||||
if ok {
|
||||
if sub.checkDuplicate != nil {
|
||||
if sub.checkDuplicate(extractEventID(message[10+len(subid):]), r.URL) {
|
||||
return
|
||||
}
|
||||
} else if sub.checkDuplicateReplaceable != nil {
|
||||
if sub.checkDuplicateReplaceable(
|
||||
ReplaceableKey{extractEventPubKey(message), extractDTag(message)},
|
||||
extractTimestamp(message),
|
||||
) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envelope, err := ParseMessage(message)
|
||||
if envelope == nil {
|
||||
if r.customHandler != nil && err == UnknownLabel {
|
||||
r.customHandler(message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch env := envelope.(type) {
|
||||
case *NoticeEnvelope:
|
||||
// see WithNoticeHandler
|
||||
if r.noticeHandler != nil {
|
||||
r.noticeHandler(string(*env))
|
||||
} else {
|
||||
log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env))
|
||||
}
|
||||
case *AuthEnvelope:
|
||||
if env.Challenge == nil {
|
||||
return
|
||||
}
|
||||
r.challenge = *env.Challenge
|
||||
case *EventEnvelope:
|
||||
// we already have the subscription from the pre-check above, so we can just reuse it
|
||||
if sub == nil {
|
||||
// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
|
||||
return
|
||||
} else {
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !sub.match(env.Event) {
|
||||
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, sub.Filter, env.Event)
|
||||
return
|
||||
}
|
||||
|
||||
// check signature, ignore invalid, except from trusted (AssumeValid) relays
|
||||
if !r.AssumeValid {
|
||||
if !env.Event.VerifySignature() {
|
||||
InfoLogger.Printf("{%s} bad signature on %s\n", r.URL, env.Event.ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch this to the internal .events channel of the subscription
|
||||
sub.dispatchEvent(env.Event)
|
||||
}
|
||||
case *EOSEEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(string(*env))); ok {
|
||||
subscription.dispatchEose()
|
||||
}
|
||||
case *ClosedEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok {
|
||||
subscription.handleClosed(env.Reason)
|
||||
}
|
||||
case *CountEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
||||
subscription.countResult <- *env
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
okCallback(env.OK, env.Reason)
|
||||
} else {
|
||||
InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write queues an arbitrary message to be sent to the relay.
|
||||
func (r *Relay) Write(msg []byte) <-chan error {
|
||||
func (r *Relay) Write(msg []byte) {
|
||||
select {
|
||||
case r.Connection.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
||||
case <-r.Connection.closedNotify:
|
||||
case <-r.connectionContext.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
||||
func (r *Relay) WriteWithError(msg []byte) error {
|
||||
ch := make(chan error)
|
||||
select {
|
||||
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||
case r.Connection.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||
case <-r.Connection.closedNotify:
|
||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
||||
case <-r.connectionContext.Done():
|
||||
go func() { ch <- fmt.Errorf("connection closed") }()
|
||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||
}
|
||||
return ch
|
||||
return <-ch
|
||||
}
|
||||
|
||||
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
|
||||
@@ -342,7 +292,7 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
||||
|
||||
// publish event
|
||||
envb, _ := env.MarshalJSON()
|
||||
if err := <-r.Write(envb); err != nil {
|
||||
if err := r.WriteWithError(envb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -508,11 +458,6 @@ func (r *Relay) close(reason error) error {
|
||||
return fmt.Errorf("relay not connected")
|
||||
}
|
||||
|
||||
err := r.Connection.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user