From 1df85217d99e74572d9af02601e39e2c14eae1b5 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Sat, 28 Feb 2026 07:33:48 -0300 Subject: [PATCH] merge connection into relay, do all the closing logic on context cancelation and have closeMutex be a channelmutex. --- connection.go | 158 ----------------------------- eose_test.go | 2 +- go.mod | 1 + go.sum | 2 + relay.go | 236 +++++++++++++++++++++++++++++++++++++++---- subscription_test.go | 2 +- 6 files changed, 219 insertions(+), 182 deletions(-) delete mode 100644 connection.go diff --git a/connection.go b/connection.go deleted file mode 100644 index 141c565..0000000 --- a/connection.go +++ /dev/null @@ -1,158 +0,0 @@ -package nostr - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net/http" - "net/textproto" - "sync/atomic" - "time" - - ws "github.com/coder/websocket" -) - -var ErrDisconnected = errors.New("") - -type writeRequest struct { - msg []byte - answer chan error -} - -func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error { - debugLogf("{%s} connecting!\n", r.URL) - - dialCtx := ctx - if _, ok := dialCtx.Deadline(); !ok { - // if no timeout is set, force it to 7 seconds - dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) - } - - dialOpts := &ws.DialOptions{ - HTTPHeader: http.Header{ - textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"}, - }, - CompressionMode: ws.CompressionContextTakeover, - HTTPClient: httpClient, - } - for k, v := range r.requestHeader { - dialOpts.HTTPHeader[k] = v - } - - c, _, err := ws.Dial(dialCtx, r.URL, dialOpts) - if err != nil { - return err - } - c.SetReadLimit(2 << 24) // 33MB - - // this will tell if the connection is closed - - // ping every 29 seconds - ticker := time.NewTicker(29 * time.Second) - - // main websocket loop - readQueue := make(chan string) - - r.conn = c - r.writeQueue = make(chan writeRequest) - r.closed = &atomic.Bool{} - r.closedNotify = make(chan struct{}) - - go func() { - pingAttempt := 0 - - for { - select { - case <-ctx.Done(): - r.closeConnection(ws.StatusNormalClosure, "") - debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx)) - return - case <-r.closedNotify: - return - case <-ticker.C: - debugLogf("{%s} pinging\n", r.URL) - ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long")) - err := c.Ping(ctx) - cancel() - - if err != nil { - pingAttempt++ - debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err) - - if pingAttempt >= 3 { - debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL) - err = r.Close() // this should trigger a context cancelation - if err != nil { - debugLogf("{%s} failed to close relay: %v", r.URL, err) - } - } - - continue - } - - // ping was OK - debugLogf("{%s} ping OK", r.URL) - pingAttempt = 0 - case wr := <-r.writeQueue: - debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg)) - ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long")) - err := c.Write(ctx, ws.MessageText, wr.msg) - cancel() - if err != nil { - debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err) - r.closeConnection(ws.StatusAbnormalClosure, "write failed") - if wr.answer != nil { - wr.answer <- err - } - return - } - if wr.answer != nil { - close(wr.answer) - } - case msg := <-readQueue: - debugLogf("{%s} received %v\n", r.URL, msg) - r.handleMessage(msg) - } - } - }() - - // read loop -- loops back to the main loop - go func() { - buf := new(bytes.Buffer) - - for { - buf.Reset() - - _, reader, err := c.Reader(ctx) - if err != nil { - debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err) - r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader") - return - } - if _, err := io.Copy(buf, reader); err != nil { - debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err) - r.closeConnection(ws.StatusAbnormalClosure, "failed to read") - return - } - - readQueue <- string(buf.Bytes()) - } - }() - - return nil -} - -func (r *Relay) closeConnection(code ws.StatusCode, reason string) { - wasClosed := r.closed.Swap(true) - if !wasClosed { - r.conn.Close(code, reason) - r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason)) - r.closeMutex.Lock() - close(r.closedNotify) - close(r.writeQueue) - r.conn = nil - r.closeMutex.Unlock() - } -} diff --git a/eose_test.go b/eose_test.go index 71305a2..6a8e857 100644 --- a/eose_test.go +++ b/eose_test.go @@ -17,7 +17,7 @@ func TestEOSEMadness(t *testing.T) { }, SubscriptionOptions{}) assert.NoError(t, err) - timeout := time.After(3 * time.Second) + timeout := time.After(2 * time.Second) n := 0 e := 0 diff --git a/go.mod b/go.mod index 95864b5..4a055ec 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( ) require ( + fiatjaf.com/lib v0.3.5 github.com/dgraph-io/ristretto/v2 v2.3.0 github.com/go-git/go-git/v5 v5.16.3 github.com/sivukhin/godjot v1.0.6 diff --git a/go.sum b/go.sum index b4ececd..70d7cb3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +fiatjaf.com/lib v0.3.5 h1:9kSuAOqHuhShH5sMeLJwceojUCxQ6SS7zrfC2EnJ0+E= +fiatjaf.com/lib v0.3.5/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8= github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc= github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw= github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc= diff --git a/relay.go b/relay.go index bdb72e8..ac76dec 100644 --- a/relay.go +++ b/relay.go @@ -1,28 +1,54 @@ package nostr import ( + "bytes" "context" "crypto/tls" "errors" "fmt" + "io" "iter" "log" "math" "net/http" + "net/textproto" "strconv" "sync" "sync/atomic" "time" + "fiatjaf.com/lib/channelmutex" ws "github.com/coder/websocket" "github.com/puzpuzpuz/xsync/v3" ) var subscriptionIDCounter atomic.Int64 +var ( + ErrDisconnected = errors.New("") + ErrPingFailed = errors.New("") +) + +type writeRequest struct { + msg []byte + answer chan error +} + +type closeCause struct { + code ws.StatusCode + reason string +} + +func (c closeCause) Error() string { + if c.reason == "" { + return "relay closed" + } + return c.reason +} + // Relay represents a connection to a Nostr relay. type Relay struct { - closeMutex sync.Mutex + closeMutex *channelmutex.Mutex URL string requestHeader http.Header // e.g. for origin header @@ -66,8 +92,26 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay { customHandler: opts.CustomHandler, noticeHandler: opts.NoticeHandler, authHandler: opts.AuthHandler, + closeMutex: channelmutex.New(), + closed: &atomic.Bool{}, + closedNotify: make(chan struct{}), } + go func() { + <-ctx.Done() + cause := context.Cause(ctx) + code := ws.StatusNormalClosure + reason := "" + var cc closeCause + if errors.As(cause, &cc) { + code = cc.code + reason = cc.reason + } else if cause != nil { + reason = cause.Error() + } + r.closeConnection(code, reason) + }() + return r } @@ -136,6 +180,9 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro if r.connectionContext == nil || r.Subscriptions == nil { return fmt.Errorf("relay must be initialized with a call to NewRelay()") } + if r.connectionContext.Err() != nil { + return fmt.Errorf("relay context canceled") + } if r.URL == "" { return fmt.Errorf("invalid relay URL '%s'", r.URL) @@ -148,6 +195,157 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro return nil } +func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error { + debugLogf("{%s} connecting!\n", r.URL) + if r.connectionContext.Err() != nil { + return fmt.Errorf("relay context canceled") + } + + dialCtx := ctx + if _, ok := dialCtx.Deadline(); !ok { + // if no timeout is set, force it to 7 seconds + dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) + } + + dialOpts := &ws.DialOptions{ + HTTPHeader: http.Header{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"}, + }, + CompressionMode: ws.CompressionContextTakeover, + HTTPClient: httpClient, + } + for k, v := range r.requestHeader { + dialOpts.HTTPHeader[k] = v + } + + c, _, err := ws.Dial(dialCtx, r.URL, dialOpts) + if err != nil { + return err + } + c.SetReadLimit(2 << 24) // 33MB + + // ping every 19 seconds + ticker := time.NewTicker(19 * time.Second) + + // main websocket loop + readQueue := make(chan string) + + r.conn = c + r.writeQueue = make(chan writeRequest) + if r.closed == nil { + r.closed = &atomic.Bool{} + } + r.closed.Store(false) + r.closedNotify = make(chan struct{}) + + connCtx := r.connectionContext + go func() { + pingAttempt := 0 + + for { + select { + case <-connCtx.Done(): + return + case <-r.closedNotify: + return + case <-ticker.C: + debugLogf("{%s} pinging\n", r.URL) + pingCtx, cancel := context.WithTimeoutCause(connCtx, time.Millisecond*800, errors.New("ping took too long")) + err := c.Ping(pingCtx) + cancel() + + if err != nil { + pingAttempt++ + debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err) + + if pingAttempt >= 3 { + debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL) + _ = r.close(ErrPingFailed) + } + + continue + } + + // ping was OK + debugLogf("{%s} ping OK", r.URL) + pingAttempt = 0 + case wr := <-r.writeQueue: + debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg)) + writeCtx, cancel := context.WithTimeoutCause(connCtx, time.Second*10, errors.New("write took too long")) + err := c.Write(writeCtx, ws.MessageText, wr.msg) + cancel() + if err != nil { + debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err) + _ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "write failed"}) + if wr.answer != nil { + wr.answer <- err + } + return + } + if wr.answer != nil { + close(wr.answer) + } + case msg := <-readQueue: + debugLogf("{%s} received %v\n", r.URL, msg) + r.handleMessage(msg) + } + } + }() + + // read loop -- loops back to the main loop + go func() { + buf := new(bytes.Buffer) + + for { + buf.Reset() + + _, reader, err := c.Reader(connCtx) + if err != nil { + debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err) + _ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to get reader"}) + return + } + if _, err := io.Copy(buf, reader); err != nil { + debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err) + _ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to read"}) + return + } + + msg := string(buf.Bytes()) + readQueue <- msg + } + }() + + return nil +} + +func (r *Relay) closeConnection(code ws.StatusCode, reason string) { + if r.closed == nil { + r.closed = &atomic.Bool{} + } + wasClosed := r.closed.Swap(true) + if wasClosed { + return + } + if r.closeMutex != nil { + r.closeMutex.Invalidate() + } + if r.conn != nil { + _ = r.conn.Close(code, reason) + } + if r.closeMutex != nil { + r.closeMutex.Lock() + if r.closedNotify != nil { + close(r.closedNotify) + } + if r.writeQueue != nil { + close(r.writeQueue) + } + r.conn = nil + r.closeMutex.Unlock() + } +} + func (r *Relay) handleMessage(message string) { // if this is an "EVENT" we will have this preparser logic that should speed things up a little // as we skip handling duplicate events @@ -244,14 +442,16 @@ func (r *Relay) handleMessage(message string) { // Write queues an arbitrary message to be sent to the relay. func (r *Relay) Write(msg []byte) { - r.closeMutex.Lock() - defer r.closeMutex.Unlock() select { + case <-r.closeMutex.C(): // this locks the mutex case <-r.closedNotify: return - default: + case <-r.connectionContext.Done(): + return } + defer r.closeMutex.Unlock() + select { case <-r.connectionContext.Done(): case r.writeQueue <- writeRequest{msg: msg, answer: nil}: @@ -261,13 +461,17 @@ func (r *Relay) Write(msg []byte) { // WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed). func (r *Relay) WriteWithError(msg []byte) error { ch := make(chan error) - r.closeMutex.Lock() - defer r.closeMutex.Unlock() + select { + case <-r.closeMutex.C(): // this locks the channel/mutex case <-r.closedNotify: return fmt.Errorf("failed to write to %s: ", r.URL) - default: + case <-r.connectionContext.Done(): + return fmt.Errorf("failed to write to %s: ", r.URL) } + + defer r.closeMutex.Unlock() + select { case <-r.connectionContext.Done(): return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext)) @@ -396,9 +600,9 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO go func() { select { case <-r.closedNotify: - sub.unsub(ErrDisconnected) + sub.cancel(ErrDisconnected) case <-ctx.Done(): - sub.unsub(nil) + sub.cancel(nil) } }() @@ -448,7 +652,7 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub go func() { time.Sleep(opts.MaxWaitForEOSE) - sub.eoseTimedOut <- struct{}{} + close(sub.eoseTimedOut) sub.dispatchEose() }() } @@ -535,19 +739,7 @@ func (r *Relay) Close() error { } func (r *Relay) close(reason error) error { - r.closeMutex.Lock() - defer r.closeMutex.Unlock() - - if r.connectionContextCancel == nil { - return fmt.Errorf("relay already closed") - } - - if r.conn == nil { - return fmt.Errorf("relay not connected") - } - r.connectionContextCancel(reason) - return nil } diff --git a/subscription_test.go b/subscription_test.go index ca33321..50ff179 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -const RELAY = "wss://nos.lol" +const RELAY = "wss://relay.damus.io" // test if we can fetch a couple of random events func TestSubscribeBasic(t *testing.T) {