merge connection into relay, do all the closing logic on context cancelation and have closeMutex be a channelmutex.
This commit is contained in:
-158
@@ -1,158 +0,0 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
var ErrDisconnected = errors.New("<disconnected>")
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
||||
debugLogf("{%s} connecting!\n", r.URL)
|
||||
|
||||
dialCtx := ctx
|
||||
if _, ok := dialCtx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
}
|
||||
|
||||
dialOpts := &ws.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
||||
},
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
for k, v := range r.requestHeader {
|
||||
dialOpts.HTTPHeader[k] = v
|
||||
}
|
||||
|
||||
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.SetReadLimit(2 << 24) // 33MB
|
||||
|
||||
// this will tell if the connection is closed
|
||||
|
||||
// ping every 29 seconds
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
|
||||
// main websocket loop
|
||||
readQueue := make(chan string)
|
||||
|
||||
r.conn = c
|
||||
r.writeQueue = make(chan writeRequest)
|
||||
r.closed = &atomic.Bool{}
|
||||
r.closedNotify = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
pingAttempt := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.closeConnection(ws.StatusNormalClosure, "")
|
||||
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
|
||||
return
|
||||
case <-r.closedNotify:
|
||||
return
|
||||
case <-ticker.C:
|
||||
debugLogf("{%s} pinging\n", r.URL)
|
||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
|
||||
err := c.Ping(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
pingAttempt++
|
||||
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
||||
|
||||
if pingAttempt >= 3 {
|
||||
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
||||
err = r.Close() // this should trigger a context cancelation
|
||||
if err != nil {
|
||||
debugLogf("{%s} failed to close relay: %v", r.URL, err)
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// ping was OK
|
||||
debugLogf("{%s} ping OK", r.URL)
|
||||
pingAttempt = 0
|
||||
case wr := <-r.writeQueue:
|
||||
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
|
||||
err := c.Write(ctx, ws.MessageText, wr.msg)
|
||||
cancel()
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
|
||||
if wr.answer != nil {
|
||||
wr.answer <- err
|
||||
}
|
||||
return
|
||||
}
|
||||
if wr.answer != nil {
|
||||
close(wr.answer)
|
||||
}
|
||||
case msg := <-readQueue:
|
||||
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||
r.handleMessage(msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// read loop -- loops back to the main loop
|
||||
go func() {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
for {
|
||||
buf.Reset()
|
||||
|
||||
_, reader, err := c.Reader(ctx)
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
|
||||
return
|
||||
}
|
||||
|
||||
readQueue <- string(buf.Bytes())
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||
wasClosed := r.closed.Swap(true)
|
||||
if !wasClosed {
|
||||
r.conn.Close(code, reason)
|
||||
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
|
||||
r.closeMutex.Lock()
|
||||
close(r.closedNotify)
|
||||
close(r.writeQueue)
|
||||
r.conn = nil
|
||||
r.closeMutex.Unlock()
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -17,7 +17,7 @@ func TestEOSEMadness(t *testing.T) {
|
||||
}, SubscriptionOptions{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeout := time.After(3 * time.Second)
|
||||
timeout := time.After(2 * time.Second)
|
||||
n := 0
|
||||
e := 0
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
fiatjaf.com/lib v0.3.5
|
||||
github.com/dgraph-io/ristretto/v2 v2.3.0
|
||||
github.com/go-git/go-git/v5 v5.16.3
|
||||
github.com/sivukhin/godjot v1.0.6
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
fiatjaf.com/lib v0.3.5 h1:9kSuAOqHuhShH5sMeLJwceojUCxQ6SS7zrfC2EnJ0+E=
|
||||
fiatjaf.com/lib v0.3.5/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8=
|
||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
|
||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
|
||||
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
|
||||
|
||||
@@ -1,28 +1,54 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fiatjaf.com/lib/channelmutex"
|
||||
ws "github.com/coder/websocket"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
)
|
||||
|
||||
var subscriptionIDCounter atomic.Int64
|
||||
|
||||
var (
|
||||
ErrDisconnected = errors.New("<disconnected>")
|
||||
ErrPingFailed = errors.New("<ping failed>")
|
||||
)
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
type closeCause struct {
|
||||
code ws.StatusCode
|
||||
reason string
|
||||
}
|
||||
|
||||
func (c closeCause) Error() string {
|
||||
if c.reason == "" {
|
||||
return "relay closed"
|
||||
}
|
||||
return c.reason
|
||||
}
|
||||
|
||||
// Relay represents a connection to a Nostr relay.
|
||||
type Relay struct {
|
||||
closeMutex sync.Mutex
|
||||
closeMutex *channelmutex.Mutex
|
||||
|
||||
URL string
|
||||
requestHeader http.Header // e.g. for origin header
|
||||
@@ -66,8 +92,26 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
||||
customHandler: opts.CustomHandler,
|
||||
noticeHandler: opts.NoticeHandler,
|
||||
authHandler: opts.AuthHandler,
|
||||
closeMutex: channelmutex.New(),
|
||||
closed: &atomic.Bool{},
|
||||
closedNotify: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cause := context.Cause(ctx)
|
||||
code := ws.StatusNormalClosure
|
||||
reason := ""
|
||||
var cc closeCause
|
||||
if errors.As(cause, &cc) {
|
||||
code = cc.code
|
||||
reason = cc.reason
|
||||
} else if cause != nil {
|
||||
reason = cause.Error()
|
||||
}
|
||||
r.closeConnection(code, reason)
|
||||
}()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -136,6 +180,9 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
||||
if r.connectionContext == nil || r.Subscriptions == nil {
|
||||
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
|
||||
}
|
||||
if r.connectionContext.Err() != nil {
|
||||
return fmt.Errorf("relay context canceled")
|
||||
}
|
||||
|
||||
if r.URL == "" {
|
||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||
@@ -148,6 +195,157 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
||||
debugLogf("{%s} connecting!\n", r.URL)
|
||||
if r.connectionContext.Err() != nil {
|
||||
return fmt.Errorf("relay context canceled")
|
||||
}
|
||||
|
||||
dialCtx := ctx
|
||||
if _, ok := dialCtx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
}
|
||||
|
||||
dialOpts := &ws.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
||||
},
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
for k, v := range r.requestHeader {
|
||||
dialOpts.HTTPHeader[k] = v
|
||||
}
|
||||
|
||||
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.SetReadLimit(2 << 24) // 33MB
|
||||
|
||||
// ping every 19 seconds
|
||||
ticker := time.NewTicker(19 * time.Second)
|
||||
|
||||
// main websocket loop
|
||||
readQueue := make(chan string)
|
||||
|
||||
r.conn = c
|
||||
r.writeQueue = make(chan writeRequest)
|
||||
if r.closed == nil {
|
||||
r.closed = &atomic.Bool{}
|
||||
}
|
||||
r.closed.Store(false)
|
||||
r.closedNotify = make(chan struct{})
|
||||
|
||||
connCtx := r.connectionContext
|
||||
go func() {
|
||||
pingAttempt := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
case <-r.closedNotify:
|
||||
return
|
||||
case <-ticker.C:
|
||||
debugLogf("{%s} pinging\n", r.URL)
|
||||
pingCtx, cancel := context.WithTimeoutCause(connCtx, time.Millisecond*800, errors.New("ping took too long"))
|
||||
err := c.Ping(pingCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
pingAttempt++
|
||||
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
||||
|
||||
if pingAttempt >= 3 {
|
||||
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
||||
_ = r.close(ErrPingFailed)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// ping was OK
|
||||
debugLogf("{%s} ping OK", r.URL)
|
||||
pingAttempt = 0
|
||||
case wr := <-r.writeQueue:
|
||||
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||
writeCtx, cancel := context.WithTimeoutCause(connCtx, time.Second*10, errors.New("write took too long"))
|
||||
err := c.Write(writeCtx, ws.MessageText, wr.msg)
|
||||
cancel()
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "write failed"})
|
||||
if wr.answer != nil {
|
||||
wr.answer <- err
|
||||
}
|
||||
return
|
||||
}
|
||||
if wr.answer != nil {
|
||||
close(wr.answer)
|
||||
}
|
||||
case msg := <-readQueue:
|
||||
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||
r.handleMessage(msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// read loop -- loops back to the main loop
|
||||
go func() {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
for {
|
||||
buf.Reset()
|
||||
|
||||
_, reader, err := c.Reader(connCtx)
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to get reader"})
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to read"})
|
||||
return
|
||||
}
|
||||
|
||||
msg := string(buf.Bytes())
|
||||
readQueue <- msg
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||
if r.closed == nil {
|
||||
r.closed = &atomic.Bool{}
|
||||
}
|
||||
wasClosed := r.closed.Swap(true)
|
||||
if wasClosed {
|
||||
return
|
||||
}
|
||||
if r.closeMutex != nil {
|
||||
r.closeMutex.Invalidate()
|
||||
}
|
||||
if r.conn != nil {
|
||||
_ = r.conn.Close(code, reason)
|
||||
}
|
||||
if r.closeMutex != nil {
|
||||
r.closeMutex.Lock()
|
||||
if r.closedNotify != nil {
|
||||
close(r.closedNotify)
|
||||
}
|
||||
if r.writeQueue != nil {
|
||||
close(r.writeQueue)
|
||||
}
|
||||
r.conn = nil
|
||||
r.closeMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) handleMessage(message string) {
|
||||
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
||||
// as we skip handling duplicate events
|
||||
@@ -244,14 +442,16 @@ func (r *Relay) handleMessage(message string) {
|
||||
|
||||
// Write queues an arbitrary message to be sent to the relay.
|
||||
func (r *Relay) Write(msg []byte) {
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
select {
|
||||
case <-r.closeMutex.C(): // this locks the mutex
|
||||
case <-r.closedNotify:
|
||||
return
|
||||
default:
|
||||
case <-r.connectionContext.Done():
|
||||
return
|
||||
}
|
||||
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
||||
@@ -261,13 +461,17 @@ func (r *Relay) Write(msg []byte) {
|
||||
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
||||
func (r *Relay) WriteWithError(msg []byte) error {
|
||||
ch := make(chan error)
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-r.closeMutex.C(): // this locks the channel/mutex
|
||||
case <-r.closedNotify:
|
||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
||||
default:
|
||||
case <-r.connectionContext.Done():
|
||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
||||
}
|
||||
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||
@@ -396,9 +600,9 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO
|
||||
go func() {
|
||||
select {
|
||||
case <-r.closedNotify:
|
||||
sub.unsub(ErrDisconnected)
|
||||
sub.cancel(ErrDisconnected)
|
||||
case <-ctx.Done():
|
||||
sub.unsub(nil)
|
||||
sub.cancel(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -448,7 +652,7 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
||||
|
||||
go func() {
|
||||
time.Sleep(opts.MaxWaitForEOSE)
|
||||
sub.eoseTimedOut <- struct{}{}
|
||||
close(sub.eoseTimedOut)
|
||||
sub.dispatchEose()
|
||||
}()
|
||||
}
|
||||
@@ -535,19 +739,7 @@ func (r *Relay) Close() error {
|
||||
}
|
||||
|
||||
func (r *Relay) close(reason error) error {
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
if r.connectionContextCancel == nil {
|
||||
return fmt.Errorf("relay already closed")
|
||||
}
|
||||
|
||||
if r.conn == nil {
|
||||
return fmt.Errorf("relay not connected")
|
||||
}
|
||||
|
||||
r.connectionContextCancel(reason)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const RELAY = "wss://nos.lol"
|
||||
const RELAY = "wss://relay.damus.io"
|
||||
|
||||
// test if we can fetch a couple of random events
|
||||
func TestSubscribeBasic(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user