merge connection into relay, do all the closing logic on context cancelation and have closeMutex be a channelmutex.

This commit is contained in:
fiatjaf
2026-02-28 07:33:48 -03:00
parent 195cb944e2
commit 1df85217d9
6 changed files with 219 additions and 182 deletions
-158
View File
@@ -1,158 +0,0 @@
package nostr
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
"sync/atomic"
"time"
ws "github.com/coder/websocket"
)
var ErrDisconnected = errors.New("<disconnected>")
type writeRequest struct {
msg []byte
answer chan error
}
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
debugLogf("{%s} connecting!\n", r.URL)
dialCtx := ctx
if _, ok := dialCtx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
}
dialOpts := &ws.DialOptions{
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
},
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: httpClient,
}
for k, v := range r.requestHeader {
dialOpts.HTTPHeader[k] = v
}
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
if err != nil {
return err
}
c.SetReadLimit(2 << 24) // 33MB
// this will tell if the connection is closed
// ping every 29 seconds
ticker := time.NewTicker(29 * time.Second)
// main websocket loop
readQueue := make(chan string)
r.conn = c
r.writeQueue = make(chan writeRequest)
r.closed = &atomic.Bool{}
r.closedNotify = make(chan struct{})
go func() {
pingAttempt := 0
for {
select {
case <-ctx.Done():
r.closeConnection(ws.StatusNormalClosure, "")
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
return
case <-r.closedNotify:
return
case <-ticker.C:
debugLogf("{%s} pinging\n", r.URL)
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
err := c.Ping(ctx)
cancel()
if err != nil {
pingAttempt++
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
if pingAttempt >= 3 {
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
err = r.Close() // this should trigger a context cancelation
if err != nil {
debugLogf("{%s} failed to close relay: %v", r.URL, err)
}
}
continue
}
// ping was OK
debugLogf("{%s} ping OK", r.URL)
pingAttempt = 0
case wr := <-r.writeQueue:
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
err := c.Write(ctx, ws.MessageText, wr.msg)
cancel()
if err != nil {
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
if wr.answer != nil {
wr.answer <- err
}
return
}
if wr.answer != nil {
close(wr.answer)
}
case msg := <-readQueue:
debugLogf("{%s} received %v\n", r.URL, msg)
r.handleMessage(msg)
}
}
}()
// read loop -- loops back to the main loop
go func() {
buf := new(bytes.Buffer)
for {
buf.Reset()
_, reader, err := c.Reader(ctx)
if err != nil {
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
return
}
if _, err := io.Copy(buf, reader); err != nil {
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
return
}
readQueue <- string(buf.Bytes())
}
}()
return nil
}
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
wasClosed := r.closed.Swap(true)
if !wasClosed {
r.conn.Close(code, reason)
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
r.closeMutex.Lock()
close(r.closedNotify)
close(r.writeQueue)
r.conn = nil
r.closeMutex.Unlock()
}
}
+1 -1
View File
@@ -17,7 +17,7 @@ func TestEOSEMadness(t *testing.T) {
}, SubscriptionOptions{})
assert.NoError(t, err)
timeout := time.After(3 * time.Second)
timeout := time.After(2 * time.Second)
n := 0
e := 0
+1
View File
@@ -40,6 +40,7 @@ require (
)
require (
fiatjaf.com/lib v0.3.5
github.com/dgraph-io/ristretto/v2 v2.3.0
github.com/go-git/go-git/v5 v5.16.3
github.com/sivukhin/godjot v1.0.6
+2
View File
@@ -1,3 +1,5 @@
fiatjaf.com/lib v0.3.5 h1:9kSuAOqHuhShH5sMeLJwceojUCxQ6SS7zrfC2EnJ0+E=
fiatjaf.com/lib v0.3.5/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8=
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
+214 -22
View File
@@ -1,28 +1,54 @@
package nostr
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"iter"
"log"
"math"
"net/http"
"net/textproto"
"strconv"
"sync"
"sync/atomic"
"time"
"fiatjaf.com/lib/channelmutex"
ws "github.com/coder/websocket"
"github.com/puzpuzpuz/xsync/v3"
)
var subscriptionIDCounter atomic.Int64
var (
ErrDisconnected = errors.New("<disconnected>")
ErrPingFailed = errors.New("<ping failed>")
)
type writeRequest struct {
msg []byte
answer chan error
}
type closeCause struct {
code ws.StatusCode
reason string
}
func (c closeCause) Error() string {
if c.reason == "" {
return "relay closed"
}
return c.reason
}
// Relay represents a connection to a Nostr relay.
type Relay struct {
closeMutex sync.Mutex
closeMutex *channelmutex.Mutex
URL string
requestHeader http.Header // e.g. for origin header
@@ -66,8 +92,26 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
customHandler: opts.CustomHandler,
noticeHandler: opts.NoticeHandler,
authHandler: opts.AuthHandler,
closeMutex: channelmutex.New(),
closed: &atomic.Bool{},
closedNotify: make(chan struct{}),
}
go func() {
<-ctx.Done()
cause := context.Cause(ctx)
code := ws.StatusNormalClosure
reason := ""
var cc closeCause
if errors.As(cause, &cc) {
code = cc.code
reason = cc.reason
} else if cause != nil {
reason = cause.Error()
}
r.closeConnection(code, reason)
}()
return r
}
@@ -136,6 +180,9 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
if r.connectionContext == nil || r.Subscriptions == nil {
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
}
if r.connectionContext.Err() != nil {
return fmt.Errorf("relay context canceled")
}
if r.URL == "" {
return fmt.Errorf("invalid relay URL '%s'", r.URL)
@@ -148,6 +195,157 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
return nil
}
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
debugLogf("{%s} connecting!\n", r.URL)
if r.connectionContext.Err() != nil {
return fmt.Errorf("relay context canceled")
}
dialCtx := ctx
if _, ok := dialCtx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
}
dialOpts := &ws.DialOptions{
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
},
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: httpClient,
}
for k, v := range r.requestHeader {
dialOpts.HTTPHeader[k] = v
}
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
if err != nil {
return err
}
c.SetReadLimit(2 << 24) // 33MB
// ping every 19 seconds
ticker := time.NewTicker(19 * time.Second)
// main websocket loop
readQueue := make(chan string)
r.conn = c
r.writeQueue = make(chan writeRequest)
if r.closed == nil {
r.closed = &atomic.Bool{}
}
r.closed.Store(false)
r.closedNotify = make(chan struct{})
connCtx := r.connectionContext
go func() {
pingAttempt := 0
for {
select {
case <-connCtx.Done():
return
case <-r.closedNotify:
return
case <-ticker.C:
debugLogf("{%s} pinging\n", r.URL)
pingCtx, cancel := context.WithTimeoutCause(connCtx, time.Millisecond*800, errors.New("ping took too long"))
err := c.Ping(pingCtx)
cancel()
if err != nil {
pingAttempt++
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
if pingAttempt >= 3 {
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
_ = r.close(ErrPingFailed)
}
continue
}
// ping was OK
debugLogf("{%s} ping OK", r.URL)
pingAttempt = 0
case wr := <-r.writeQueue:
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
writeCtx, cancel := context.WithTimeoutCause(connCtx, time.Second*10, errors.New("write took too long"))
err := c.Write(writeCtx, ws.MessageText, wr.msg)
cancel()
if err != nil {
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "write failed"})
if wr.answer != nil {
wr.answer <- err
}
return
}
if wr.answer != nil {
close(wr.answer)
}
case msg := <-readQueue:
debugLogf("{%s} received %v\n", r.URL, msg)
r.handleMessage(msg)
}
}
}()
// read loop -- loops back to the main loop
go func() {
buf := new(bytes.Buffer)
for {
buf.Reset()
_, reader, err := c.Reader(connCtx)
if err != nil {
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to get reader"})
return
}
if _, err := io.Copy(buf, reader); err != nil {
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to read"})
return
}
msg := string(buf.Bytes())
readQueue <- msg
}
}()
return nil
}
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
if r.closed == nil {
r.closed = &atomic.Bool{}
}
wasClosed := r.closed.Swap(true)
if wasClosed {
return
}
if r.closeMutex != nil {
r.closeMutex.Invalidate()
}
if r.conn != nil {
_ = r.conn.Close(code, reason)
}
if r.closeMutex != nil {
r.closeMutex.Lock()
if r.closedNotify != nil {
close(r.closedNotify)
}
if r.writeQueue != nil {
close(r.writeQueue)
}
r.conn = nil
r.closeMutex.Unlock()
}
}
func (r *Relay) handleMessage(message string) {
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
// as we skip handling duplicate events
@@ -244,14 +442,16 @@ func (r *Relay) handleMessage(message string) {
// Write queues an arbitrary message to be sent to the relay.
func (r *Relay) Write(msg []byte) {
r.closeMutex.Lock()
defer r.closeMutex.Unlock()
select {
case <-r.closeMutex.C(): // this locks the mutex
case <-r.closedNotify:
return
default:
case <-r.connectionContext.Done():
return
}
defer r.closeMutex.Unlock()
select {
case <-r.connectionContext.Done():
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
@@ -261,13 +461,17 @@ func (r *Relay) Write(msg []byte) {
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
func (r *Relay) WriteWithError(msg []byte) error {
ch := make(chan error)
r.closeMutex.Lock()
defer r.closeMutex.Unlock()
select {
case <-r.closeMutex.C(): // this locks the channel/mutex
case <-r.closedNotify:
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
default:
case <-r.connectionContext.Done():
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
}
defer r.closeMutex.Unlock()
select {
case <-r.connectionContext.Done():
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
@@ -396,9 +600,9 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO
go func() {
select {
case <-r.closedNotify:
sub.unsub(ErrDisconnected)
sub.cancel(ErrDisconnected)
case <-ctx.Done():
sub.unsub(nil)
sub.cancel(nil)
}
}()
@@ -448,7 +652,7 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
go func() {
time.Sleep(opts.MaxWaitForEOSE)
sub.eoseTimedOut <- struct{}{}
close(sub.eoseTimedOut)
sub.dispatchEose()
}()
}
@@ -535,19 +739,7 @@ func (r *Relay) Close() error {
}
func (r *Relay) close(reason error) error {
r.closeMutex.Lock()
defer r.closeMutex.Unlock()
if r.connectionContextCancel == nil {
return fmt.Errorf("relay already closed")
}
if r.conn == nil {
return fmt.Errorf("relay not connected")
}
r.connectionContextCancel(reason)
return nil
}
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
)
const RELAY = "wss://nos.lol"
const RELAY = "wss://relay.damus.io"
// test if we can fetch a couple of random events
func TestSubscribeBasic(t *testing.T) {