Compare commits
83 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4daeb8737c | |||
| 7ab69cbc60 | |||
| 029f4eb0d8 | |||
| cf734a3ac7 | |||
| d92a0cde16 | |||
| 5944a3ead6 | |||
| 3e35681cb9 | |||
| 8515153df2 | |||
| 98fa53464e | |||
| 29cdd48fcb | |||
| 181de14642 | |||
| 1794f0690f | |||
| 12af4717d4 | |||
| b989b66bb7 | |||
| 4261bc88f8 | |||
| a8205a3790 | |||
| 0152341144 | |||
| 9bf9816c15 | |||
| 82f2fbdb99 | |||
| d5b54a1c91 | |||
| 637412fd38 | |||
| 9b881801d8 | |||
| 371cecdb84 | |||
| 2735abe060 | |||
| b9a3e78752 | |||
| ff03090610 | |||
| 72a5be58d7 | |||
| 2c30300756 | |||
| d1fdc262f2 | |||
| 117a304f68 | |||
| ac2d4579f1 | |||
| 56610a32e6 | |||
| d4940c7858 | |||
| 172e7890b9 | |||
| 3acfbbca0a | |||
| b5974cfa45 | |||
| c74ac74a0e | |||
| ec6f3f8a41 | |||
| d43fbbf02d | |||
| 6a686c31af | |||
| a6fdcd8b30 | |||
| e675f04bd2 | |||
| 0630bbe4e9 | |||
| 55c5194bdf | |||
| f3f5c3982d | |||
| 1520264394 | |||
| 2cec1c9434 | |||
| 6cbe984e16 | |||
| 5a0b18e65a | |||
| bb4093d834 | |||
| 3bd059d1f9 | |||
| 681bd55e55 | |||
| 4e490879b5 | |||
| 4348c64b14 | |||
| 2c0d9712e3 | |||
| 4719c0bc9f | |||
| 163e59e1f1 | |||
| 21ce0046c0 | |||
| 1d14e6bebe | |||
| 23d525f067 | |||
| 4dab261bdf | |||
| 44c429d6b1 | |||
| 4b5c51ffc0 | |||
| 5de9501556 | |||
| 8ba05114cd | |||
| 1df85217d9 | |||
| 195cb944e2 | |||
| c31b92707b | |||
| 00ffe16cb7 | |||
| 4d1b6c1df0 | |||
| 62d15178ec | |||
| 32dd39da81 | |||
| 7aa127a8c3 | |||
| 55cc52876a | |||
| 137c09369a | |||
| d445ba9919 | |||
| d30c1bff46 | |||
| 65ef1c50a7 | |||
| 7a4b71b39b | |||
| 3f52d10421 | |||
| a98ac0d050 | |||
| 28bef1c990 | |||
| beb8a72491 |
-158
@@ -1,158 +0,0 @@
|
|||||||
package nostr
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/textproto"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
ws "github.com/coder/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrDisconnected = errors.New("<disconnected>")
|
|
||||||
|
|
||||||
type writeRequest struct {
|
|
||||||
msg []byte
|
|
||||||
answer chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
|
||||||
debugLogf("{%s} connecting!\n", r.URL)
|
|
||||||
|
|
||||||
dialCtx := ctx
|
|
||||||
if _, ok := dialCtx.Deadline(); !ok {
|
|
||||||
// if no timeout is set, force it to 7 seconds
|
|
||||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
|
||||||
}
|
|
||||||
|
|
||||||
dialOpts := &ws.DialOptions{
|
|
||||||
HTTPHeader: http.Header{
|
|
||||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
|
||||||
},
|
|
||||||
CompressionMode: ws.CompressionContextTakeover,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
}
|
|
||||||
for k, v := range r.requestHeader {
|
|
||||||
dialOpts.HTTPHeader[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.SetReadLimit(2 << 24) // 33MB
|
|
||||||
|
|
||||||
// this will tell if the connection is closed
|
|
||||||
|
|
||||||
// ping every 29 seconds
|
|
||||||
ticker := time.NewTicker(29 * time.Second)
|
|
||||||
|
|
||||||
// main websocket loop
|
|
||||||
readQueue := make(chan string)
|
|
||||||
|
|
||||||
r.conn = c
|
|
||||||
r.writeQueue = make(chan writeRequest)
|
|
||||||
r.closed = &atomic.Bool{}
|
|
||||||
r.closedNotify = make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
pingAttempt := 0
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
r.closeConnection(ws.StatusNormalClosure, "")
|
|
||||||
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
|
|
||||||
return
|
|
||||||
case <-r.closedNotify:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
debugLogf("{%s} pinging\n", r.URL)
|
|
||||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
|
|
||||||
err := c.Ping(ctx)
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
pingAttempt++
|
|
||||||
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
|
||||||
|
|
||||||
if pingAttempt >= 3 {
|
|
||||||
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
|
||||||
err = r.Close() // this should trigger a context cancelation
|
|
||||||
if err != nil {
|
|
||||||
debugLogf("{%s} failed to close relay: %v", r.URL, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// ping was OK
|
|
||||||
debugLogf("{%s} ping OK", r.URL)
|
|
||||||
pingAttempt = 0
|
|
||||||
case wr := <-r.writeQueue:
|
|
||||||
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
|
||||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
|
|
||||||
err := c.Write(ctx, ws.MessageText, wr.msg)
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
|
||||||
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
|
|
||||||
if wr.answer != nil {
|
|
||||||
wr.answer <- err
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if wr.answer != nil {
|
|
||||||
close(wr.answer)
|
|
||||||
}
|
|
||||||
case msg := <-readQueue:
|
|
||||||
debugLogf("{%s} received %v\n", r.URL, msg)
|
|
||||||
r.handleMessage(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// read loop -- loops back to the main loop
|
|
||||||
go func() {
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
|
|
||||||
for {
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
_, reader, err := c.Reader(ctx)
|
|
||||||
if err != nil {
|
|
||||||
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
|
||||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := io.Copy(buf, reader); err != nil {
|
|
||||||
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
|
||||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
readQueue <- string(buf.Bytes())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
|
||||||
wasClosed := r.closed.Swap(true)
|
|
||||||
if !wasClosed {
|
|
||||||
r.conn.Close(code, reason)
|
|
||||||
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
|
|
||||||
r.closeMutex.Lock()
|
|
||||||
close(r.closedNotify)
|
|
||||||
close(r.writeQueue)
|
|
||||||
r.conn = nil
|
|
||||||
r.closeMutex.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+1
-1
@@ -17,7 +17,7 @@ func TestEOSEMadness(t *testing.T) {
|
|||||||
}, SubscriptionOptions{})
|
}, SubscriptionOptions{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
timeout := time.After(3 * time.Second)
|
timeout := time.After(2 * time.Second)
|
||||||
n := 0
|
n := 0
|
||||||
e := 0
|
e := 0
|
||||||
|
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ type Store interface {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
[](https://pkg.go.dev/fiatjaf.com/nostr/eventstore) [](https://fiatjaf.com/nostr/eventstore/actions/workflows/test.yml)
|
|
||||||
|
|
||||||
## Available Implementations
|
## Available Implementations
|
||||||
|
|
||||||
- **bleve**: Full-text search and indexing using the Bleve search library
|
- **bleve**: Full-text search and indexing using the Bleve search library
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
"fiatjaf.com/nostr/eventstore/lmdb"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBleveFlow(t *testing.T) {
|
func TestBleveFlow(t *testing.T) {
|
||||||
@@ -21,7 +22,9 @@ func TestBleveFlow(t *testing.T) {
|
|||||||
Path: "/tmp/blevetest-bleve",
|
Path: "/tmp/blevetest-bleve",
|
||||||
RawEventStore: bb,
|
RawEventStore: bb,
|
||||||
}
|
}
|
||||||
bl.Init()
|
err := bl.Init()
|
||||||
|
require.NoError(t, err, "init")
|
||||||
|
|
||||||
defer bl.Close()
|
defer bl.Close()
|
||||||
|
|
||||||
willDelete := make([]nostr.Event, 0, 3)
|
willDelete := make([]nostr.Event, 0, 3)
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package bleve
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *BleveBackend) DeleteEvent(id nostr.ID) error {
|
|
||||||
return b.index.Delete(id.Hex())
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
package bleve
|
|
||||||
|
|
||||||
const (
|
|
||||||
idField = "i"
|
|
||||||
contentField = "c"
|
|
||||||
kindField = "k"
|
|
||||||
createdAtField = "a"
|
|
||||||
pubkeyField = "p"
|
|
||||||
)
|
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package bleve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// lexer tokenizes the input string
|
||||||
|
type Lexer struct {
|
||||||
|
input string
|
||||||
|
pos int
|
||||||
|
|
||||||
|
peekedQueue []Token
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLexer(input string) *Lexer {
|
||||||
|
return &Lexer{input: input, pos: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) peek() rune {
|
||||||
|
if l.pos >= len(l.input) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return rune(l.input[l.pos])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) advance() rune {
|
||||||
|
if l.pos >= len(l.input) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
ch := rune(l.input[l.pos])
|
||||||
|
l.pos++
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) skipWhitespace() {
|
||||||
|
for l.peek() != 0 && unicode.IsSpace(l.peek()) {
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) readWord() string {
|
||||||
|
start := l.pos
|
||||||
|
|
||||||
|
// read regular word (alphanumeric, hyphens, underscores)
|
||||||
|
for l.peek() != 0 && !unicode.IsSpace(l.peek()) &&
|
||||||
|
l.peek() != '(' && l.peek() != ')' && l.peek() != '"' {
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.input[start:l.pos]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) PeekToken() Token {
|
||||||
|
next := l.NextToken()
|
||||||
|
l.peekedQueue = append(l.peekedQueue, next)
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) ReturnToken(tok Token) {
|
||||||
|
l.peekedQueue = append(l.peekedQueue, tok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) NextToken() (tok Token) {
|
||||||
|
if len(l.peekedQueue) > 0 {
|
||||||
|
next := l.peekedQueue[len(l.peekedQueue)-1]
|
||||||
|
l.peekedQueue = l.peekedQueue[0 : len(l.peekedQueue)-1]
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
l.skipWhitespace()
|
||||||
|
|
||||||
|
if l.pos >= len(l.input) {
|
||||||
|
return Token{Type: TokenEOF}
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := l.peek()
|
||||||
|
|
||||||
|
switch ch {
|
||||||
|
case '(':
|
||||||
|
l.advance()
|
||||||
|
return Token{Type: TokenLParen, Value: "("}
|
||||||
|
case ')':
|
||||||
|
l.advance()
|
||||||
|
return Token{Type: TokenRParen, Value: ")"}
|
||||||
|
case '"':
|
||||||
|
l.advance()
|
||||||
|
return Token{Type: TokenQuote, Value: "\""}
|
||||||
|
default:
|
||||||
|
word := l.readWord()
|
||||||
|
upperWord := strings.ToUpper(word)
|
||||||
|
|
||||||
|
switch upperWord {
|
||||||
|
case "OR", "||":
|
||||||
|
return Token{Type: TokenOR, Value: word}
|
||||||
|
case "AND", "&&":
|
||||||
|
return Token{Type: TokenAND, Value: word}
|
||||||
|
case "NOT", "!":
|
||||||
|
return Token{Type: TokenNOT, Value: word}
|
||||||
|
default:
|
||||||
|
return Token{Type: TokenWord, Value: word}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
+425
-15
@@ -1,34 +1,101 @@
|
|||||||
package bleve
|
package bleve
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore"
|
"fiatjaf.com/nostr/eventstore"
|
||||||
|
"fiatjaf.com/nostr/nip27"
|
||||||
|
"fiatjaf.com/nostr/nip73"
|
||||||
|
"fiatjaf.com/nostr/sdk"
|
||||||
bleve "github.com/blevesearch/bleve/v2"
|
bleve "github.com/blevesearch/bleve/v2"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/analyzer/simple"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/gl"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/in"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/pl"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
|
||||||
|
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
|
||||||
bleveMapping "github.com/blevesearch/bleve/v2/mapping"
|
bleveMapping "github.com/blevesearch/bleve/v2/mapping"
|
||||||
|
bleveQuery "github.com/blevesearch/bleve/v2/search/query"
|
||||||
|
"github.com/pemistahl/lingua-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ eventstore.Store = (*BleveBackend)(nil)
|
const (
|
||||||
|
labelContentField = "c"
|
||||||
|
labelKindField = "k"
|
||||||
|
labelCreatedAtField = "a"
|
||||||
|
labelAuthorField = "p"
|
||||||
|
labelReferencesField = "r"
|
||||||
|
labelExtrasField = "x"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SupportedLanguages = []lingua.Language{
|
||||||
|
// each of these translates to a specific bleve analyzer
|
||||||
|
// except for japanese-korean-chinese that all use the same "cjk" analyzer
|
||||||
|
lingua.Arabic,
|
||||||
|
lingua.Chinese,
|
||||||
|
lingua.Croatian,
|
||||||
|
lingua.Danish,
|
||||||
|
lingua.Dutch,
|
||||||
|
lingua.English,
|
||||||
|
lingua.Finnish,
|
||||||
|
lingua.French,
|
||||||
|
lingua.German,
|
||||||
|
lingua.Hindi,
|
||||||
|
lingua.Hungarian,
|
||||||
|
lingua.Italian,
|
||||||
|
lingua.Japanese,
|
||||||
|
lingua.Korean,
|
||||||
|
lingua.Persian,
|
||||||
|
lingua.Polish,
|
||||||
|
lingua.Portuguese,
|
||||||
|
lingua.Romanian,
|
||||||
|
lingua.Russian,
|
||||||
|
lingua.Spanish,
|
||||||
|
lingua.Swedish,
|
||||||
|
lingua.Turkish,
|
||||||
|
}
|
||||||
|
|
||||||
type BleveBackend struct {
|
type BleveBackend struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
// Path is where the index will be saved
|
Path string
|
||||||
Path string
|
|
||||||
|
|
||||||
// RawEventStore is where we'll fetch the raw events from
|
|
||||||
// bleve will only store ids, so the actual events must be somewhere else
|
|
||||||
RawEventStore eventstore.Store
|
RawEventStore eventstore.Store
|
||||||
|
ReadOnly bool
|
||||||
|
OpenTimeout time.Duration
|
||||||
|
|
||||||
index bleve.Index
|
IndexableKinds []nostr.Kind
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BleveBackend) Close() {
|
Languages []lingua.Language
|
||||||
if b.index != nil {
|
languageCodes []string
|
||||||
b.index.Close()
|
|
||||||
}
|
index bleve.Index
|
||||||
|
detector lingua.LanguageDetector
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BleveBackend) Init() error {
|
func (b *BleveBackend) Init() error {
|
||||||
@@ -38,12 +105,94 @@ func (b *BleveBackend) Init() error {
|
|||||||
if b.RawEventStore == nil {
|
if b.RawEventStore == nil {
|
||||||
return fmt.Errorf("missing RawEventStore")
|
return fmt.Errorf("missing RawEventStore")
|
||||||
}
|
}
|
||||||
|
if len(b.Languages) == 0 {
|
||||||
|
return fmt.Errorf("missing Languages")
|
||||||
|
}
|
||||||
|
if len(b.IndexableKinds) == 0 {
|
||||||
|
b.IndexableKinds = []nostr.Kind{0, 1, 6, 11, 16, 20, 21, 22, 24, 1111, 9802, 30023, 30818}
|
||||||
|
}
|
||||||
|
|
||||||
// try to open existing index
|
validLanguages := make([]lingua.Language, 0, len(b.Languages))
|
||||||
index, err := bleve.Open(b.Path)
|
b.languageCodes = make([]string, 0, len(b.Languages))
|
||||||
|
for _, lang := range b.Languages {
|
||||||
|
var code string
|
||||||
|
|
||||||
|
switch lang {
|
||||||
|
case lingua.Chinese, lingua.Korean, lingua.Japanese:
|
||||||
|
code = "cjk"
|
||||||
|
default:
|
||||||
|
code = strings.ToLower(lang.IsoCode639_1().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(b.languageCodes, code) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
validLanguages = append(validLanguages, lang)
|
||||||
|
b.languageCodes = append(b.languageCodes, code)
|
||||||
|
}
|
||||||
|
b.Languages = validLanguages
|
||||||
|
|
||||||
|
opts := map[string]any{
|
||||||
|
"read_only": b.ReadOnly,
|
||||||
|
}
|
||||||
|
if b.OpenTimeout != 0 {
|
||||||
|
opts["bolt_timeout"] = b.OpenTimeout.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := bleve.OpenUsing(b.Path, opts)
|
||||||
if err == bleve.ErrorIndexPathDoesNotExist {
|
if err == bleve.ErrorIndexPathDoesNotExist {
|
||||||
// create new index with default mapping
|
|
||||||
mapping := bleveMapping.NewIndexMapping()
|
mapping := bleveMapping.NewIndexMapping()
|
||||||
|
mapping.DefaultMapping.Dynamic = false
|
||||||
|
doc := bleveMapping.NewDocumentStaticMapping()
|
||||||
|
|
||||||
|
for _, code := range b.languageCodes {
|
||||||
|
contentField := bleveMapping.NewTextFieldMapping()
|
||||||
|
contentField.Analyzer = code
|
||||||
|
contentField.Store = false
|
||||||
|
contentField.IncludeTermVectors = false
|
||||||
|
contentField.DocValues = false
|
||||||
|
contentField.IncludeInAll = false
|
||||||
|
doc.AddFieldMappingsAt(labelContentField+"_"+code, contentField)
|
||||||
|
}
|
||||||
|
|
||||||
|
extrasField := bleveMapping.NewTextFieldMapping()
|
||||||
|
extrasField.Analyzer = "simple"
|
||||||
|
extrasField.Store = false
|
||||||
|
extrasField.IncludeTermVectors = false
|
||||||
|
extrasField.DocValues = false
|
||||||
|
extrasField.IncludeInAll = false
|
||||||
|
doc.AddFieldMappingsAt(labelExtrasField, extrasField)
|
||||||
|
|
||||||
|
referencesField := bleveMapping.NewKeywordFieldMapping()
|
||||||
|
referencesField.DocValues = false
|
||||||
|
referencesField.Store = false
|
||||||
|
referencesField.IncludeTermVectors = false
|
||||||
|
referencesField.IncludeInAll = false
|
||||||
|
doc.AddFieldMappingsAt(labelReferencesField, referencesField)
|
||||||
|
|
||||||
|
authorField := bleveMapping.NewKeywordFieldMapping()
|
||||||
|
authorField.DocValues = false
|
||||||
|
authorField.Store = false
|
||||||
|
authorField.IncludeTermVectors = false
|
||||||
|
doc.AddFieldMappingsAt(labelAuthorField, authorField)
|
||||||
|
|
||||||
|
kindField := bleveMapping.NewKeywordFieldMapping()
|
||||||
|
kindField.DocValues = false
|
||||||
|
kindField.Store = false
|
||||||
|
kindField.IncludeTermVectors = false
|
||||||
|
kindField.IncludeInAll = false
|
||||||
|
doc.AddFieldMappingsAt(labelKindField, kindField)
|
||||||
|
|
||||||
|
timestampField := bleveMapping.NewDateTimeFieldMapping()
|
||||||
|
timestampField.DocValues = false
|
||||||
|
timestampField.Store = false
|
||||||
|
timestampField.IncludeTermVectors = false
|
||||||
|
timestampField.IncludeInAll = false
|
||||||
|
doc.AddFieldMappingsAt(labelCreatedAtField, timestampField)
|
||||||
|
|
||||||
|
mapping.AddDocumentMapping("_default", doc)
|
||||||
|
|
||||||
index, err = bleve.New(b.Path, mapping)
|
index, err = bleve.New(b.Path, mapping)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating index: %w", err)
|
return fmt.Errorf("error creating index: %w", err)
|
||||||
@@ -53,6 +202,116 @@ func (b *BleveBackend) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b.index = index
|
b.index = index
|
||||||
|
b.detector = lingua.NewLanguageDetectorBuilder().
|
||||||
|
FromLanguages(b.Languages...).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BleveBackend) Close() {
|
||||||
|
if b != nil && b.index != nil {
|
||||||
|
b.index.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BleveBackend) SaveEvent(event nostr.Event) error {
|
||||||
|
if slices.Contains(b.IndexableKinds, event.Kind) {
|
||||||
|
return b.indexEvent(event)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BleveBackend) DeleteEvent(id nostr.ID) error {
|
||||||
|
if b != nil && b.index != nil {
|
||||||
|
return b.index.Delete(id.Hex())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BleveBackend) indexEvent(evt nostr.Event) error {
|
||||||
|
docID := evt.ID
|
||||||
|
|
||||||
|
var references []string
|
||||||
|
var extras string
|
||||||
|
|
||||||
|
switch evt.Kind {
|
||||||
|
case 6, 16:
|
||||||
|
var innerEvt nostr.Event
|
||||||
|
if err := json.Unmarshal([]byte(evt.Content), &innerEvt); err != nil || !innerEvt.VerifySignature() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
evt = innerEvt
|
||||||
|
case 0:
|
||||||
|
var pm sdk.ProfileMetadata
|
||||||
|
if err := json.Unmarshal([]byte(evt.Content), &pm); err == nil {
|
||||||
|
evt.Content = pm.Name + "\n" + pm.DisplayName + "\n" + pm.About
|
||||||
|
references = append(references, pm.NIP05)
|
||||||
|
}
|
||||||
|
case 9802:
|
||||||
|
for _, tag := range evt.Tags {
|
||||||
|
if len(tag) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch tag[0] {
|
||||||
|
case "comment":
|
||||||
|
evt.Content += "\n\n" + tag[1]
|
||||||
|
case "e":
|
||||||
|
if ptr, err := nostr.EventPointerFromTag(tag); err == nil {
|
||||||
|
references = append(references, ptr.AsTagReference())
|
||||||
|
}
|
||||||
|
case "a":
|
||||||
|
if ptr, err := nostr.EntityPointerFromTag(tag); err == nil {
|
||||||
|
references = append(references, ptr.AsTagReference())
|
||||||
|
}
|
||||||
|
case "r":
|
||||||
|
references = append(references, tag[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := map[string]any{
|
||||||
|
labelKindField: strconv.Itoa(int(evt.Kind)),
|
||||||
|
labelAuthorField: evt.PubKey.Hex()[56:],
|
||||||
|
labelCreatedAtField: evt.CreatedAt.Time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
content.Grow(len(evt.Content))
|
||||||
|
|
||||||
|
for block := range nip27.Parse(evt.Content) {
|
||||||
|
if block.Pointer == nil {
|
||||||
|
content.WriteString(strings.TrimSpace(block.Text))
|
||||||
|
} else {
|
||||||
|
references = append(references, block.Pointer.AsTagReference())
|
||||||
|
if ep, ok := block.Pointer.(nip73.ExternalPointer); ok {
|
||||||
|
extras += ep.Thing + " "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
indexableContent := content.String()
|
||||||
|
lang, ok := b.detector.DetectLanguageOf(indexableContent)
|
||||||
|
if !ok {
|
||||||
|
lang = lingua.English
|
||||||
|
}
|
||||||
|
|
||||||
|
var analyzerLangCode string
|
||||||
|
switch lang {
|
||||||
|
case lingua.Japanese, lingua.Chinese, lingua.Korean:
|
||||||
|
analyzerLangCode = "cjk"
|
||||||
|
default:
|
||||||
|
analyzerLangCode = strings.ToLower(lang.IsoCode639_1().String())
|
||||||
|
}
|
||||||
|
doc[labelContentField+"_"+analyzerLangCode] = indexableContent
|
||||||
|
|
||||||
|
doc[labelReferencesField] = references
|
||||||
|
doc[labelExtrasField] = extras
|
||||||
|
|
||||||
|
if err := b.index.Index(docID.Hex(), doc); err != nil {
|
||||||
|
return fmt.Errorf("failed to index '%s' document: %w", docID.Hex(), err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,3 +323,154 @@ func (b *BleveBackend) CountEvents(filter nostr.Filter) (uint32, error) {
|
|||||||
|
|
||||||
return 0, errors.New("not supported")
|
return 0, errors.New("not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BleveBackend) QueryEvents(filter nostr.Filter, maxLimit int) iter.Seq[nostr.Event] {
|
||||||
|
return func(yield func(nostr.Event) bool) {
|
||||||
|
if tlimit := filter.GetTheoreticalLimit(); tlimit == 0 {
|
||||||
|
return
|
||||||
|
} else if tlimit < maxLimit {
|
||||||
|
maxLimit = tlimit
|
||||||
|
}
|
||||||
|
|
||||||
|
filter.Search = strings.TrimSpace(filter.Search)
|
||||||
|
if len(filter.Search) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
and := make([]bleveQuery.Query, 0, 3)
|
||||||
|
|
||||||
|
searchC := strings.Builder{}
|
||||||
|
searchC.Grow(len(filter.Search))
|
||||||
|
|
||||||
|
for block := range nip27.Parse(filter.Search) {
|
||||||
|
if block.Pointer != nil {
|
||||||
|
genericRef := bleve.NewTermQuery(block.Pointer.AsTagReference())
|
||||||
|
genericRef.SetField(labelReferencesField)
|
||||||
|
genericRef.SetBoost(2)
|
||||||
|
|
||||||
|
var ref bleveQuery.Query = genericRef
|
||||||
|
if profile, ok := block.Pointer.(nostr.ProfilePointer); ok {
|
||||||
|
authorQuery := bleve.NewTermQuery(profile.PublicKey.Hex()[56:])
|
||||||
|
authorQuery.SetField(labelAuthorField)
|
||||||
|
authorQuery.SetBoost(2)
|
||||||
|
orRef := bleve.NewDisjunctionQuery()
|
||||||
|
orRef.AddQuery(genericRef)
|
||||||
|
orRef.AddQuery(authorQuery)
|
||||||
|
ref = orRef
|
||||||
|
} else if addr, ok := block.Pointer.(nostr.EntityPointer); ok {
|
||||||
|
authorQuery := bleve.NewTermQuery(addr.PublicKey.Hex()[56:])
|
||||||
|
authorQuery.SetField(labelAuthorField)
|
||||||
|
authorQuery.SetBoost(2)
|
||||||
|
orRef := bleve.NewDisjunctionQuery()
|
||||||
|
orRef.AddQuery(genericRef)
|
||||||
|
orRef.AddQuery(authorQuery)
|
||||||
|
ref = orRef
|
||||||
|
}
|
||||||
|
and = append(and, ref)
|
||||||
|
} else {
|
||||||
|
searchC.WriteString(strings.TrimSpace(block.Text))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
searchContent := searchC.String()
|
||||||
|
|
||||||
|
var exactMatches []string
|
||||||
|
if len(searchContent) > 0 {
|
||||||
|
contentQueries := make([]bleveQuery.Query, 0, len(b.Languages)+1)
|
||||||
|
|
||||||
|
searchQ, exactMatches_, err := parse(searchContent, labelContentField+"_"+b.languageCodes[0])
|
||||||
|
if err != nil {
|
||||||
|
for _, code := range b.languageCodes {
|
||||||
|
match := bleve.NewMatchQuery(searchContent)
|
||||||
|
match.SetField(labelContentField + "_" + code)
|
||||||
|
contentQueries = append(contentQueries, match)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
contentQueries = append(contentQueries, searchQ)
|
||||||
|
for _, code := range b.languageCodes[1:] {
|
||||||
|
searchQ, _, _ := parse(searchContent, labelContentField+"_"+code)
|
||||||
|
contentQueries = append(contentQueries, searchQ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
exactMatches = exactMatches_
|
||||||
|
|
||||||
|
extrasQ := bleve.NewMatchQuery(searchContent)
|
||||||
|
extrasQ.SetField(labelExtrasField)
|
||||||
|
contentQueries = append(contentQueries, extrasQ)
|
||||||
|
|
||||||
|
and = append(and, bleveQuery.NewDisjunctionQuery(contentQueries))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filter.Kinds) > 0 {
|
||||||
|
eitherKind := bleve.NewDisjunctionQuery()
|
||||||
|
for _, kind := range filter.Kinds {
|
||||||
|
kindQ := bleve.NewTermQuery(strconv.Itoa(int(kind)))
|
||||||
|
kindQ.SetField(labelKindField)
|
||||||
|
eitherKind.AddQuery(kindQ)
|
||||||
|
}
|
||||||
|
and = append(and, eitherKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filter.Authors) > 0 {
|
||||||
|
eitherPubkey := bleve.NewDisjunctionQuery()
|
||||||
|
for _, pubkey := range filter.Authors {
|
||||||
|
pubkeyQ := bleve.NewTermQuery(pubkey.Hex()[56:])
|
||||||
|
pubkeyQ.SetField(labelAuthorField)
|
||||||
|
eitherPubkey.AddQuery(pubkeyQ)
|
||||||
|
}
|
||||||
|
and = append(and, eitherPubkey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Since != 0 || filter.Until != 0 {
|
||||||
|
var min time.Time
|
||||||
|
if filter.Since != 0 {
|
||||||
|
min = filter.Since.Time()
|
||||||
|
}
|
||||||
|
var max time.Time
|
||||||
|
if filter.Until != 0 {
|
||||||
|
max = filter.Until.Time()
|
||||||
|
} else {
|
||||||
|
max = time.Now()
|
||||||
|
}
|
||||||
|
dateRangeQ := bleve.NewDateRangeQuery(min, max)
|
||||||
|
dateRangeQ.SetField(labelCreatedAtField)
|
||||||
|
and = append(and, dateRangeQ)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := bleveQuery.NewConjunctionQuery(and)
|
||||||
|
req := bleve.NewSearchRequest(q)
|
||||||
|
req.Size = maxLimit
|
||||||
|
req.From = 0
|
||||||
|
req.Explain = true
|
||||||
|
|
||||||
|
result, err := b.index.Search(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resultHit:
|
||||||
|
for _, hit := range result.Hits {
|
||||||
|
id, err := nostr.IDFromHex(hit.ID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for evt := range b.RawEventStore.QueryEvents(nostr.Filter{IDs: []nostr.ID{id}}, 1) {
|
||||||
|
for _, exactMatch := range exactMatches {
|
||||||
|
if !strings.Contains(strings.ToLower(evt.Content), exactMatch) {
|
||||||
|
continue resultHit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for f, v := range filter.Tags {
|
||||||
|
if !evt.Tags.ContainsAny(f, v) {
|
||||||
|
continue resultHit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(evt) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
package bleve
|
|
||||||
|
|
||||||
import (
|
|
||||||
"iter"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
bleve "github.com/blevesearch/bleve/v2"
|
|
||||||
"github.com/blevesearch/bleve/v2/search/query"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *BleveBackend) QueryEvents(filter nostr.Filter, maxLimit int) iter.Seq[nostr.Event] {
|
|
||||||
return func(yield func(nostr.Event) bool) {
|
|
||||||
if tlimit := filter.GetTheoreticalLimit(); tlimit == 0 {
|
|
||||||
return
|
|
||||||
} else if tlimit < maxLimit {
|
|
||||||
maxLimit = tlimit
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(filter.Search) < 2 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
searchQ := bleve.NewMatchQuery(filter.Search)
|
|
||||||
searchQ.SetField(contentField)
|
|
||||||
var q query.Query = searchQ
|
|
||||||
|
|
||||||
conjQueries := []query.Query{searchQ}
|
|
||||||
|
|
||||||
if len(filter.Kinds) > 0 {
|
|
||||||
eitherKind := bleve.NewDisjunctionQuery()
|
|
||||||
for _, kind := range filter.Kinds {
|
|
||||||
kindQ := bleve.NewTermQuery(strconv.Itoa(int(kind)))
|
|
||||||
kindQ.SetField(kindField)
|
|
||||||
eitherKind.AddQuery(kindQ)
|
|
||||||
}
|
|
||||||
conjQueries = append(conjQueries, eitherKind)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(filter.Authors) > 0 {
|
|
||||||
eitherPubkey := bleve.NewDisjunctionQuery()
|
|
||||||
for _, pubkey := range filter.Authors {
|
|
||||||
if len(pubkey) != 64 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pubkeyQ := bleve.NewTermQuery(pubkey.Hex()[56:])
|
|
||||||
pubkeyQ.SetField(pubkeyField)
|
|
||||||
eitherPubkey.AddQuery(pubkeyQ)
|
|
||||||
}
|
|
||||||
conjQueries = append(conjQueries, eitherPubkey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if filter.Since != 0 || filter.Until != 0 {
|
|
||||||
var min *float64
|
|
||||||
if filter.Since != 0 {
|
|
||||||
minVal := float64(filter.Since)
|
|
||||||
min = &minVal
|
|
||||||
}
|
|
||||||
var max *float64
|
|
||||||
if filter.Until != 0 {
|
|
||||||
maxVal := float64(filter.Until)
|
|
||||||
max = &maxVal
|
|
||||||
}
|
|
||||||
dateRangeQ := bleve.NewNumericRangeInclusiveQuery(min, max, nil, nil)
|
|
||||||
dateRangeQ.SetField(createdAtField)
|
|
||||||
conjQueries = append(conjQueries, dateRangeQ)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(conjQueries) > 1 {
|
|
||||||
q = bleve.NewConjunctionQuery(conjQueries...)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := bleve.NewSearchRequest(q)
|
|
||||||
req.Size = maxLimit
|
|
||||||
req.From = 0
|
|
||||||
|
|
||||||
result, err := b.index.Search(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, hit := range result.Hits {
|
|
||||||
id, err := nostr.IDFromHex(hit.ID)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for evt := range b.RawEventStore.QueryEvents(nostr.Filter{IDs: []nostr.ID{id}}, 1) {
|
|
||||||
if !yield(evt) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package bleve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
bleve "github.com/blevesearch/bleve/v2"
|
||||||
|
bleveQuery "github.com/blevesearch/bleve/v2/search/query"
|
||||||
|
)
|
||||||
|
|
||||||
|
// token types
|
||||||
|
type TokenType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenWord TokenType = iota
|
||||||
|
TokenOR
|
||||||
|
TokenAND
|
||||||
|
TokenNOT
|
||||||
|
TokenLParen
|
||||||
|
TokenRParen
|
||||||
|
TokenQuote
|
||||||
|
TokenEOF
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
Type TokenType
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Parser struct {
|
||||||
|
lexer *Lexer
|
||||||
|
field string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse(input string, field string) (bleveQuery.Query, []string, error) {
|
||||||
|
lexer := NewLexer(input)
|
||||||
|
p := &Parser{
|
||||||
|
lexer: lexer,
|
||||||
|
}
|
||||||
|
|
||||||
|
var exactMatches []string
|
||||||
|
var reusableCurrentMatch strings.Builder
|
||||||
|
var currentExactMatch *strings.Builder
|
||||||
|
var currentWords []string
|
||||||
|
var negated bool
|
||||||
|
var parents []bleveQuery.Query
|
||||||
|
var parentOps []TokenType // tracks if parent should be AND or OR
|
||||||
|
var lastOp TokenType = TokenAND // track last operator for parentheses
|
||||||
|
|
||||||
|
curr := bleve.NewBooleanQuery()
|
||||||
|
|
||||||
|
for {
|
||||||
|
token := p.lexer.NextToken()
|
||||||
|
|
||||||
|
if token.Type == TokenEOF {
|
||||||
|
if len(currentWords) > 0 {
|
||||||
|
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||||
|
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
match.SetField(field)
|
||||||
|
if negated {
|
||||||
|
curr.AddMustNot(match)
|
||||||
|
} else {
|
||||||
|
curr.AddMust(match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.Type == TokenQuote {
|
||||||
|
if currentExactMatch == nil {
|
||||||
|
currentExactMatch = &reusableCurrentMatch
|
||||||
|
} else {
|
||||||
|
exactMatches = append(exactMatches, currentExactMatch.String())
|
||||||
|
currentExactMatch.Reset()
|
||||||
|
reusableCurrentMatch = *currentExactMatch
|
||||||
|
currentExactMatch = nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentExactMatch != nil {
|
||||||
|
if currentExactMatch.Len() > 0 {
|
||||||
|
currentExactMatch.WriteByte(' ')
|
||||||
|
}
|
||||||
|
currentExactMatch.WriteString(strings.ToLower(token.Value))
|
||||||
|
currentWords = append(currentWords, token.Value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.Type == TokenWord {
|
||||||
|
currentWords = append(currentWords, token.Value)
|
||||||
|
continue
|
||||||
|
} else if len(currentWords) > 0 {
|
||||||
|
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||||
|
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
match.SetField(field)
|
||||||
|
if negated {
|
||||||
|
curr.AddMustNot(match)
|
||||||
|
} else {
|
||||||
|
curr.AddMust(match)
|
||||||
|
}
|
||||||
|
currentWords = currentWords[:0]
|
||||||
|
negated = false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch token.Type {
|
||||||
|
case TokenLParen:
|
||||||
|
// push current query to parents stack with the last operator
|
||||||
|
parents = append(parents, curr)
|
||||||
|
parentOps = append(parentOps, lastOp)
|
||||||
|
// reset lastOp to default for inner parentheses
|
||||||
|
lastOp = TokenAND
|
||||||
|
// start new boolean query for parentheses content
|
||||||
|
curr = bleve.NewBooleanQuery()
|
||||||
|
continue
|
||||||
|
case TokenRParen:
|
||||||
|
// finalize any remaining words
|
||||||
|
if len(currentWords) > 0 {
|
||||||
|
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||||
|
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
match.SetField(field)
|
||||||
|
if negated {
|
||||||
|
curr.AddMustNot(match)
|
||||||
|
} else {
|
||||||
|
curr.AddMust(match)
|
||||||
|
}
|
||||||
|
currentWords = currentWords[:0]
|
||||||
|
negated = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pop parent and merge with current
|
||||||
|
if len(parents) > 0 {
|
||||||
|
parent := parents[len(parents)-1]
|
||||||
|
op := parentOps[len(parentOps)-1]
|
||||||
|
|
||||||
|
// create a new boolean query to combine parent and current
|
||||||
|
var combined bleveQuery.Query
|
||||||
|
switch op {
|
||||||
|
case TokenOR:
|
||||||
|
or := bleve.NewDisjunctionQuery()
|
||||||
|
or.AddQuery(parent)
|
||||||
|
or.AddQuery(curr)
|
||||||
|
combined = or
|
||||||
|
case TokenAND:
|
||||||
|
and := bleve.NewConjunctionQuery()
|
||||||
|
and.AddQuery(parent)
|
||||||
|
and.AddQuery(curr)
|
||||||
|
combined = and
|
||||||
|
}
|
||||||
|
|
||||||
|
curr = bleve.NewBooleanQuery()
|
||||||
|
curr.AddMust(combined)
|
||||||
|
parents = parents[:len(parents)-1]
|
||||||
|
parentOps = parentOps[:len(parentOps)-1]
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
next := p.lexer.NextToken()
|
||||||
|
following := p.lexer.PeekToken()
|
||||||
|
if next.Type == TokenNOT {
|
||||||
|
negated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch token.Type {
|
||||||
|
case TokenOR:
|
||||||
|
if next.Type != TokenLParen && !(next.Type == TokenNOT && following.Type == TokenLParen) {
|
||||||
|
// if this is not followed by a "(" or "NOT (" consider the follow next word as the only parameter
|
||||||
|
other := bleve.NewMatchQuery(next.Value)
|
||||||
|
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
other.SetField(field)
|
||||||
|
or := bleve.NewDisjunctionQuery()
|
||||||
|
or.AddQuery(curr)
|
||||||
|
or.AddQuery(other)
|
||||||
|
curr = bleve.NewBooleanQuery()
|
||||||
|
curr.AddMust(or)
|
||||||
|
} else {
|
||||||
|
lastOp = TokenOR
|
||||||
|
}
|
||||||
|
case TokenAND:
|
||||||
|
if next.Type != TokenLParen && !(next.Type == TokenNOT && following.Type == TokenLParen) {
|
||||||
|
// if this is not followed by a "(" consider the follow next word as the only parameter
|
||||||
|
other := bleve.NewMatchQuery(next.Value)
|
||||||
|
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
other.SetField(field)
|
||||||
|
and := bleve.NewConjunctionQuery()
|
||||||
|
and.AddQuery(curr)
|
||||||
|
and.AddQuery(other)
|
||||||
|
curr = bleve.NewBooleanQuery()
|
||||||
|
curr.AddMust(and)
|
||||||
|
} else {
|
||||||
|
lastOp = TokenAND
|
||||||
|
}
|
||||||
|
case TokenNOT:
|
||||||
|
if next.Type != TokenLParen {
|
||||||
|
// if this is not followed by a "(" or "NOT (" consider the follow next word as the only parameter
|
||||||
|
other := bleve.NewMatchQuery(next.Value)
|
||||||
|
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||||
|
other.SetField(field)
|
||||||
|
curr.AddMustNot(other)
|
||||||
|
} else {
|
||||||
|
negated = true
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
p.lexer.ReturnToken(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return curr, exactMatches, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package bleve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/blevesearch/bleve/v2"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseQuery(t *testing.T) {
|
||||||
|
mapping := bleve.NewIndexMapping()
|
||||||
|
mapping.DefaultAnalyzer = "en"
|
||||||
|
index, err := bleve.NewMemOnly(mapping)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
docs := []map[string]interface{}{
|
||||||
|
{"id": "1", "phrase": "I like fruit especially banana and strawberry"},
|
||||||
|
{"id": "2", "phrase": "I like fruit like apples and oranges"},
|
||||||
|
{"id": "3", "phrase": "I like vegetables but not fruit"},
|
||||||
|
{"id": "4", "phrase": "Banana bread is delicious"},
|
||||||
|
{"id": "5", "phrase": "Strawberry jam and banana smoothie"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, doc := range docs {
|
||||||
|
err := index.Index(doc["id"].(string), doc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testQueries := []struct {
|
||||||
|
query string
|
||||||
|
expected int
|
||||||
|
exactMatches []string
|
||||||
|
}{
|
||||||
|
{"fruit", 3, nil},
|
||||||
|
{"banana (NOT delicious)", 2, nil},
|
||||||
|
{"banana (NOT delicious) bread", 0, nil},
|
||||||
|
{"smoothie OR apples", 2, nil},
|
||||||
|
{"smoothie OR apples (NOT fruit)", 1, nil},
|
||||||
|
{"\"I like\"", 3, []string{"i like"}},
|
||||||
|
{"banana \"I like fruit\" strawberries", 1, []string{"i like fruit"}},
|
||||||
|
{"\"I like fruit\" (strawberry OR apple)", 2, []string{"i like fruit"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testQueries {
|
||||||
|
query, exactMatches, err := parse(test.query, "phrase")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, test.exactMatches, exactMatches)
|
||||||
|
|
||||||
|
search := bleve.NewSearchRequest(query)
|
||||||
|
results, err := index.Search(search)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, test.expected, int(results.Total),
|
||||||
|
"query '%s' expected %d results, got %d", test.query, test.expected, results.Total)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package bleve
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
"fiatjaf.com/nostr/eventstore"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *BleveBackend) ReplaceEvent(evt nostr.Event) error {
|
|
||||||
b.Lock()
|
|
||||||
defer b.Unlock()
|
|
||||||
|
|
||||||
filter := nostr.Filter{Kinds: []nostr.Kind{evt.Kind}, Authors: []nostr.PubKey{evt.PubKey}}
|
|
||||||
if evt.Kind.IsAddressable() {
|
|
||||||
filter.Tags = nostr.TagMap{"d": []string{evt.Tags.GetD()}}
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldStore := true
|
|
||||||
for previous := range b.QueryEvents(filter, 1) {
|
|
||||||
if nostr.IsOlder(previous, evt) {
|
|
||||||
if err := b.DeleteEvent(previous.ID); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete event for replacing: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
shouldStore = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldStore {
|
|
||||||
if err := b.SaveEvent(evt); err != nil && err != eventstore.ErrDupEvent {
|
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package bleve
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *BleveBackend) SaveEvent(evt nostr.Event) error {
|
|
||||||
doc := map[string]interface{}{
|
|
||||||
contentField: evt.Content,
|
|
||||||
kindField: strconv.Itoa(int(evt.Kind)),
|
|
||||||
pubkeyField: evt.PubKey.Hex()[56:],
|
|
||||||
createdAtField: float64(evt.CreatedAt),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := b.index.Index(evt.ID.Hex(), doc); err != nil {
|
|
||||||
return fmt.Errorf("failed to index '%s' document: %w", evt.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -28,6 +28,8 @@ type BoltBackend struct {
|
|||||||
MapSize int64
|
MapSize int64
|
||||||
DB *bbolt.DB
|
DB *bbolt.DB
|
||||||
|
|
||||||
|
ReadOnly bool
|
||||||
|
|
||||||
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ func (b *BoltBackend) Init() error {
|
|||||||
Timeout: 2 * time.Second,
|
Timeout: 2 * time.Second,
|
||||||
PreLoadFreelist: true,
|
PreLoadFreelist: true,
|
||||||
FreelistType: bbolt.FreelistMapType,
|
FreelistType: bbolt.FreelistMapType,
|
||||||
|
ReadOnly: b.ReadOnly,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
"go.etcd.io/bbolt"
|
"go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
func (b *BoltBackend) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||||
return b.DB.Update(func(txn *bbolt.Tx) error {
|
err = b.DB.Update(func(txn *bbolt.Tx) error {
|
||||||
rawBucket := txn.Bucket(rawEventStore)
|
rawBucket := txn.Bucket(rawEventStore)
|
||||||
|
|
||||||
// check if we already have this id
|
// check if we already have this id
|
||||||
@@ -25,12 +25,12 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||||
var err error
|
var qerr error
|
||||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
||||||
err = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
qerr = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if qerr != nil {
|
||||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
return fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldStore := true
|
shouldStore := true
|
||||||
@@ -39,6 +39,7 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
|||||||
if err := b.delete(txn, previous.ID); err != nil {
|
if err := b.delete(txn, previous.ID); err != nil {
|
||||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
||||||
}
|
}
|
||||||
|
deleted = append(deleted, previous)
|
||||||
} else {
|
} else {
|
||||||
// there is a newer event already stored, so we won't store this
|
// there is a newer event already stored, so we won't store this
|
||||||
shouldStore = false
|
shouldStore = false
|
||||||
@@ -50,4 +51,5 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
return deleted, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
|||||||
buf[0] = 0
|
buf[0] = 0
|
||||||
|
|
||||||
if evt.Kind > MaxKind {
|
if evt.Kind > MaxKind {
|
||||||
return fmt.Errorf("kind is too big: %d, max is %d", evt.Kind, MaxKind)
|
return fmt.Errorf("kind is too big: %d, max is %d", evt.Kind, uint16(MaxKind))
|
||||||
}
|
}
|
||||||
binary.LittleEndian.PutUint16(buf[1:3], uint16(evt.Kind))
|
binary.LittleEndian.PutUint16(buf[1:3], uint16(evt.Kind))
|
||||||
|
|
||||||
if evt.CreatedAt > MaxCreatedAt {
|
if evt.CreatedAt > MaxCreatedAt {
|
||||||
return fmt.Errorf("created_at is too big: %d, max is %d", evt.CreatedAt, MaxCreatedAt)
|
return fmt.Errorf("created_at is too big: %d, max is %d", evt.CreatedAt, uint32(MaxCreatedAt))
|
||||||
}
|
}
|
||||||
binary.LittleEndian.PutUint32(buf[3:7], uint32(evt.CreatedAt))
|
binary.LittleEndian.PutUint32(buf[3:7], uint32(evt.CreatedAt))
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
|||||||
|
|
||||||
ntags := len(evt.Tags)
|
ntags := len(evt.Tags)
|
||||||
if ntags > MaxTagCount {
|
if ntags > MaxTagCount {
|
||||||
return fmt.Errorf("can't encode too many tags: %d, max is %d", ntags, MaxTagCount)
|
return fmt.Errorf("can't encode too many tags: %d, max is %d", ntags, uint16(MaxTagCount))
|
||||||
}
|
}
|
||||||
binary.LittleEndian.PutUint16(buf[137:139], uint16(ntags))
|
binary.LittleEndian.PutUint16(buf[137:139], uint16(ntags))
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
|||||||
|
|
||||||
itemCount := len(tag)
|
itemCount := len(tag)
|
||||||
if itemCount > MaxTagItemCount {
|
if itemCount > MaxTagItemCount {
|
||||||
return fmt.Errorf("can't encode a tag with so many items: %d, max is %d", itemCount, MaxTagItemCount)
|
return fmt.Errorf("can't encode a tag with so many items: %d, max is %d", itemCount, uint8(MaxTagItemCount))
|
||||||
}
|
}
|
||||||
buf[tagBase+tagOffset] = uint8(itemCount)
|
buf[tagBase+tagOffset] = uint8(itemCount)
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
|||||||
for _, item := range tag {
|
for _, item := range tag {
|
||||||
itemSize := len(item)
|
itemSize := len(item)
|
||||||
if itemSize > MaxTagItemSize {
|
if itemSize > MaxTagItemSize {
|
||||||
return fmt.Errorf("tag item is too large: %d, max is %d", itemSize, MaxTagItemSize)
|
return fmt.Errorf("tag item is too large: %d, max is %d", itemSize, uint16(MaxTagItemSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.LittleEndian.PutUint16(buf[tagBase+tagOffset+itemOffset:], uint16(itemSize))
|
binary.LittleEndian.PutUint16(buf[tagBase+tagOffset+itemOffset:], uint16(itemSize))
|
||||||
@@ -91,7 +91,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
|||||||
|
|
||||||
// content
|
// content
|
||||||
if contentLength := len(evt.Content); contentLength > MaxContentSize {
|
if contentLength := len(evt.Content); contentLength > MaxContentSize {
|
||||||
return fmt.Errorf("content is too large: %d, max is %d", contentLength, MaxContentSize)
|
return fmt.Errorf("content is too large: %d, max is %d", contentLength, uint16(MaxContentSize))
|
||||||
} else {
|
} else {
|
||||||
binary.LittleEndian.PutUint16(buf[tagBase+tagsSectionLength:], uint16(contentLength))
|
binary.LittleEndian.PutUint16(buf[tagBase+tagsSectionLength:], uint16(contentLength))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package checks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fiatjaf.com/nostr/eventstore"
|
"fiatjaf.com/nostr/eventstore"
|
||||||
"fiatjaf.com/nostr/eventstore/bleve"
|
|
||||||
"fiatjaf.com/nostr/eventstore/boltdb"
|
"fiatjaf.com/nostr/eventstore/boltdb"
|
||||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
"fiatjaf.com/nostr/eventstore/lmdb"
|
||||||
"fiatjaf.com/nostr/eventstore/mmm"
|
"fiatjaf.com/nostr/eventstore/mmm"
|
||||||
@@ -13,5 +12,4 @@ var (
|
|||||||
_ eventstore.Store = (*lmdb.LMDBBackend)(nil)
|
_ eventstore.Store = (*lmdb.LMDBBackend)(nil)
|
||||||
_ eventstore.Store = (*mmm.IndexingLayer)(nil)
|
_ eventstore.Store = (*mmm.IndexingLayer)(nil)
|
||||||
_ eventstore.Store = (*boltdb.BoltBackend)(nil)
|
_ eventstore.Store = (*boltdb.BoltBackend)(nil)
|
||||||
_ eventstore.Store = (*bleve.BleveBackend)(nil)
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,18 +36,15 @@ func (b *LMDBBackend) CountEvents(filter nostr.Filter) (uint32, error) {
|
|||||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||||
if it.exhausted ||
|
if it.exhausted ||
|
||||||
it.err != nil ||
|
it.err != nil ||
|
||||||
len(it.key) != q.keySize ||
|
len(it.key) != len(q.prefix)+4 ||
|
||||||
!bytes.HasPrefix(it.key, q.prefix) {
|
!bytes.HasPrefix(it.key, q.prefix) {
|
||||||
// either iteration has errored or we reached the end of this prefix
|
// either iteration has errored or we reached the end of this prefix
|
||||||
break // stop this cursor and move to the next one
|
break // stop this cursor and move to the next one
|
||||||
}
|
}
|
||||||
|
|
||||||
// "id" indexes don't contain a timestamp
|
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||||
if q.dbi != b.indexId {
|
if createdAt < since {
|
||||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
break
|
||||||
if createdAt < since {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
||||||
@@ -129,18 +126,15 @@ func (b *LMDBBackend) CountEventsHLL(filter nostr.Filter, offset int) (uint32, *
|
|||||||
for {
|
for {
|
||||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||||
if it.err != nil ||
|
if it.err != nil ||
|
||||||
len(it.key) != q.keySize ||
|
len(it.key) != len(q.prefix)+4 ||
|
||||||
!bytes.HasPrefix(it.key, q.prefix) {
|
!bytes.HasPrefix(it.key, q.prefix) {
|
||||||
// either iteration has errored or we reached the end of this prefix
|
// either iteration has errored or we reached the end of this prefix
|
||||||
break // stop this cursor and move to the next one
|
break // stop this cursor and move to the next one
|
||||||
}
|
}
|
||||||
|
|
||||||
// "id" indexes don't contain a timestamp
|
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||||
if q.dbi != b.indexId {
|
if createdAt < since {
|
||||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
break
|
||||||
if createdAt < since {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetch actual event (we need it regardless because we need the pubkey for the hll)
|
// fetch actual event (we need it regardless because we need the pubkey for the hll)
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (it *iterator) pull(n int, since uint32) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(it.key) != query.keySize || !bytes.HasPrefix(it.key, query.prefix) {
|
if len(it.key) != len(query.prefix)+4 || !bytes.HasPrefix(it.key, query.prefix) {
|
||||||
// we reached the end of this prefix
|
// we reached the end of this prefix
|
||||||
it.exhausted = true
|
it.exhausted = true
|
||||||
return
|
return
|
||||||
|
|||||||
+1
-27
@@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore"
|
"fiatjaf.com/nostr/eventstore"
|
||||||
@@ -34,8 +33,6 @@ type LMDBBackend struct {
|
|||||||
|
|
||||||
hllCache lmdb.DBI
|
hllCache lmdb.DBI
|
||||||
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
||||||
|
|
||||||
lastId atomic.Uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *LMDBBackend) Init() error {
|
func (b *LMDBBackend) Init() error {
|
||||||
@@ -112,7 +109,7 @@ func (b *LMDBBackend) initialize() error {
|
|||||||
env.SetMapSize(b.MapSize)
|
env.SetMapSize(b.MapSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := env.Open(b.Path, lmdb.NoTLS|lmdb.WriteMap|b.extraFlags, 0644); err != nil {
|
if err := env.Open(b.Path, lmdb.NoTLS|b.extraFlags, 0644); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
b.lmdbEnv = env
|
b.lmdbEnv = env
|
||||||
@@ -186,28 +183,5 @@ func (b *LMDBBackend) initialize() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get lastId
|
|
||||||
if err := b.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
|
||||||
txn.RawRead = true
|
|
||||||
cursor, err := txn.OpenCursor(b.rawEventStore)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer cursor.Close()
|
|
||||||
k, _, err := cursor.Get(nil, nil, lmdb.Last)
|
|
||||||
if lmdb.IsNotFound(err) {
|
|
||||||
// nothing found, so we're at zero
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
b.lastId.Store(binary.BigEndian.Uint32(k))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.migrate()
|
return b.migrate()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ func (b *LMDBBackend) queryByIds(txn *lmdb.Txn, ids []nostr.ID, yield func(nostr
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.Get(b.rawEventStore, idx)
|
|
||||||
bin, err := txn.Get(b.rawEventStore, idx)
|
bin, err := txn.Get(b.rawEventStore, idx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ type query struct {
|
|||||||
i int
|
i int
|
||||||
dbi lmdb.DBI
|
dbi lmdb.DBI
|
||||||
prefix []byte
|
prefix []byte
|
||||||
keySize int
|
|
||||||
startingPoint []byte
|
startingPoint []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,10 +39,10 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i, q := range queries {
|
for i, q := range queries {
|
||||||
sp := make([]byte, len(q.prefix))
|
sp := make([]byte, len(q.prefix)+4)
|
||||||
sp = sp[0:len(q.prefix)]
|
copy(sp[0:len(q.prefix)], q.prefix)
|
||||||
copy(sp, q.prefix)
|
binary.BigEndian.PutUint32(sp[len(q.prefix):], uint32(until))
|
||||||
queries[i].startingPoint = binary.BigEndian.AppendUint32(sp, uint32(until))
|
queries[i].startingPoint = sp
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -64,39 +63,27 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only "p" tag has a goodness of 2, so
|
// only "p" tag has a goodness of 2, so
|
||||||
if goodness == 2 {
|
if goodness == 2 && filter.Kinds != nil {
|
||||||
// this means we got a "p" tag, so we will use the ptag-kind index
|
// this means we got a "p" tag, so we will use the ptag-kind index
|
||||||
i := 0
|
i := 0
|
||||||
if filter.Kinds != nil {
|
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
for _, value := range tagValues {
|
||||||
for _, value := range tagValues {
|
if len(value) != 64 {
|
||||||
if len(value) != 64 {
|
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, kind := range filter.Kinds {
|
|
||||||
k := make([]byte, 8+2)
|
|
||||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
|
||||||
queries[i] = query{i: i, dbi: b.indexPTagKind, prefix: k[0 : 8+2], keySize: 8 + 2 + 4}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// even if there are no kinds, in that case we will just return any kind and not care
|
|
||||||
queries = make([]query, len(tagValues))
|
|
||||||
for i, value := range tagValues {
|
|
||||||
if len(value) != 64 {
|
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
k := make([]byte, 8)
|
for _, kind := range filter.Kinds {
|
||||||
|
k := make([]byte, 8+2)
|
||||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||||
}
|
}
|
||||||
queries[i] = query{i: i, dbi: b.indexPTagKind, prefix: k[0:8], keySize: 8 + 2 + 4}
|
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||||
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: b.indexPTagKind,
|
||||||
|
prefix: k[0 : 8+2],
|
||||||
|
}
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -107,7 +94,11 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
|||||||
dbi, k, offset := b.getTagIndexPrefix(tagKey, value)
|
dbi, k, offset := b.getTagIndexPrefix(tagKey, value)
|
||||||
// remove the last parts part to get just the prefix we want here
|
// remove the last parts part to get just the prefix we want here
|
||||||
prefix := k[0:offset]
|
prefix := k[0:offset]
|
||||||
queries[i] = query{i: i, dbi: dbi, prefix: prefix, keySize: len(prefix) + 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: dbi,
|
||||||
|
prefix: prefix,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
||||||
@@ -142,7 +133,11 @@ pubkeyMatching:
|
|||||||
// will use pubkey index
|
// will use pubkey index
|
||||||
queries = make([]query, len(filter.Authors))
|
queries = make([]query, len(filter.Authors))
|
||||||
for i, pk := range filter.Authors {
|
for i, pk := range filter.Authors {
|
||||||
queries[i] = query{i: i, dbi: b.indexPubkey, prefix: pk[0:8], keySize: 8 + 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: b.indexPubkey,
|
||||||
|
prefix: pk[0:8],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// will use pubkeyKind index
|
// will use pubkeyKind index
|
||||||
@@ -153,7 +148,11 @@ pubkeyMatching:
|
|||||||
prefix := make([]byte, 8+2)
|
prefix := make([]byte, 8+2)
|
||||||
copy(prefix[0:8], pk[0:8])
|
copy(prefix[0:8], pk[0:8])
|
||||||
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
||||||
queries[i] = query{i: i, dbi: b.indexPubkeyKind, prefix: prefix[0 : 8+2], keySize: 10 + 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: b.indexPubkeyKind,
|
||||||
|
prefix: prefix[0 : 8+2],
|
||||||
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,7 +169,11 @@ pubkeyMatching:
|
|||||||
for i, kind := range filter.Kinds {
|
for i, kind := range filter.Kinds {
|
||||||
prefix := make([]byte, 2)
|
prefix := make([]byte, 2)
|
||||||
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
||||||
queries[i] = query{i: i, dbi: b.indexKind, prefix: prefix[0:2], keySize: 2 + 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: b.indexKind,
|
||||||
|
prefix: prefix[0:2],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// potentially with an extra useless tag filtering
|
// potentially with an extra useless tag filtering
|
||||||
@@ -181,6 +184,10 @@ pubkeyMatching:
|
|||||||
// if we got here our query will have nothing to filter with
|
// if we got here our query will have nothing to filter with
|
||||||
queries = make([]query, 1)
|
queries = make([]query, 1)
|
||||||
prefix := make([]byte, 0)
|
prefix := make([]byte, 0)
|
||||||
queries[0] = query{i: 0, dbi: b.indexCreatedAt, prefix: prefix, keySize: 0 + 4}
|
queries[0] = query{
|
||||||
|
i: 0,
|
||||||
|
dbi: b.indexCreatedAt,
|
||||||
|
prefix: prefix,
|
||||||
|
}
|
||||||
return queries, nil, nil, "", nil, since, nil
|
return queries, nil, nil, "", nil, since, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+12
-14
@@ -2,14 +2,13 @@ package lmdb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"iter"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||||
return b.lmdbEnv.Update(func(txn *lmdb.Txn) error {
|
err = b.lmdbEnv.Update(func(txn *lmdb.Txn) error {
|
||||||
// check if we already have this id
|
// check if we already have this id
|
||||||
_, existsErr := txn.Get(b.indexId, evt.ID[0:8])
|
_, existsErr := txn.Get(b.indexId, evt.ID[0:8])
|
||||||
if existsErr == nil {
|
if existsErr == nil {
|
||||||
@@ -26,24 +25,21 @@ func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||||
var err error
|
|
||||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
|
||||||
err = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldStore := true
|
shouldStore := true
|
||||||
for previous := range results {
|
if qerr := b.query(txn, filter, 10 /* could be just 1 */, func(previous nostr.Event) bool {
|
||||||
if nostr.IsOlder(previous, evt) {
|
if nostr.IsOlder(previous, evt) {
|
||||||
if err := b.delete(txn, previous.ID); err != nil {
|
if qerr := b.delete(txn, previous.ID); qerr != nil {
|
||||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
qerr = fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, qerr)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
deleted = append(deleted, previous)
|
||||||
} else {
|
} else {
|
||||||
// there is a newer event already stored, so we won't store this
|
// there is a newer event already stored, so we won't store this
|
||||||
shouldStore = false
|
shouldStore = false
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
|
}); qerr != nil {
|
||||||
|
return fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||||
}
|
}
|
||||||
if shouldStore {
|
if shouldStore {
|
||||||
return b.save(txn, evt)
|
return b.save(txn, evt)
|
||||||
@@ -51,4 +47,6 @@ func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return deleted, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,18 +33,15 @@ func (il *IndexingLayer) CountEvents(filter nostr.Filter) (uint32, error) {
|
|||||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||||
if it.exhausted ||
|
if it.exhausted ||
|
||||||
it.err != nil ||
|
it.err != nil ||
|
||||||
len(it.key) != q.keySize ||
|
len(it.key) != len(q.prefix)+4 ||
|
||||||
!bytes.HasPrefix(it.key, q.prefix) {
|
!bytes.HasPrefix(it.key, q.prefix) {
|
||||||
// either iteration has errored or we reached the end of this prefix
|
// either iteration has errored or we reached the end of this prefix
|
||||||
break // stop this cursor and move to the next one
|
break // stop this cursor and move to the next one
|
||||||
}
|
}
|
||||||
|
|
||||||
// "id" indexes don't contain a timestamp
|
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||||
if q.timestampSize == 4 {
|
if createdAt < since {
|
||||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
break
|
||||||
if createdAt < since {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
||||||
|
|||||||
@@ -116,8 +116,7 @@ func (b *MultiMmapManager) Rescan() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.freeRanges, err = b.gatherFreeRanges(mmmtxn)
|
if err := b.gatherFreeRanges(mmmtxn); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) (positions, error) {
|
const LARGE_FREERANGE = 142
|
||||||
|
|
||||||
|
func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) error {
|
||||||
cursor, err := txn.OpenCursor(b.indexId)
|
cursor, err := txn.OpenCursor(b.indexId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open cursor on indexId: %w", err)
|
return fmt.Errorf("failed to open cursor on indexId: %w", err)
|
||||||
}
|
}
|
||||||
defer cursor.Close()
|
defer cursor.Close()
|
||||||
|
|
||||||
@@ -28,31 +30,35 @@ func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) (positions, error) {
|
|||||||
usedPositions = append(usedPositions, position{start: b.mmapfEnd, size: 0})
|
usedPositions = append(usedPositions, position{start: b.mmapfEnd, size: 0})
|
||||||
|
|
||||||
// calculate free ranges as gaps between used positions
|
// calculate free ranges as gaps between used positions
|
||||||
freeRanges := make(positions, 0, len(usedPositions)/2)
|
b.freeRangesAll = make(positions, 0, len(usedPositions))
|
||||||
|
b.freeRangesLarge = make([]position, 0, len(usedPositions)/10)
|
||||||
var currentStart uint64 = 0
|
var currentStart uint64 = 0
|
||||||
for _, used := range usedPositions {
|
for _, used := range usedPositions {
|
||||||
if used.start > currentStart {
|
if used.start > currentStart {
|
||||||
// gap from currentStart to pos.start
|
// gap from currentStart to pos.start
|
||||||
freeSize := used.start - currentStart
|
freeSize := used.start - currentStart
|
||||||
if freeSize > 0 {
|
if freeSize > 0 {
|
||||||
freeRanges = append(freeRanges, position{
|
fr := position{
|
||||||
start: currentStart,
|
start: currentStart,
|
||||||
size: uint32(freeSize),
|
size: uint32(freeSize),
|
||||||
})
|
}
|
||||||
|
b.freeRangesAll = append(b.freeRangesAll, fr)
|
||||||
|
if fr.isLarge() {
|
||||||
|
b.freeRangesLarge = append(b.freeRangesLarge, fr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
currentStart = used.start + uint64(used.size)
|
currentStart = used.start + uint64(used.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
return freeRanges, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
||||||
// use binary search to find the insertion point for the new pos
|
// use binary search to find the insertion point for the new pos
|
||||||
idx, exists := slices.BinarySearchFunc(b.freeRanges, newFreeRange.start, func(item position, target uint64) int {
|
idx, exists := slices.BinarySearchFunc(b.freeRangesAll, newFreeRange.start, func(item position, target uint64) int {
|
||||||
return cmp.Compare(item.start, target)
|
return cmp.Compare(item.start, target)
|
||||||
})
|
})
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
panic(fmt.Errorf("can't add free range that already exists: %s", newFreeRange))
|
panic(fmt.Errorf("can't add free range that already exists: %s", newFreeRange))
|
||||||
}
|
}
|
||||||
@@ -62,7 +68,7 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
|||||||
|
|
||||||
// check the range immediately before
|
// check the range immediately before
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
before := b.freeRanges[idx-1]
|
before := b.freeRangesAll[idx-1]
|
||||||
if before.start+uint64(before.size) == newFreeRange.start {
|
if before.start+uint64(before.size) == newFreeRange.start {
|
||||||
deleteStart = idx - 1
|
deleteStart = idx - 1
|
||||||
deleting++
|
deleting++
|
||||||
@@ -72,8 +78,8 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check the range immediately after
|
// check the range immediately after
|
||||||
if idx < len(b.freeRanges) {
|
if idx < len(b.freeRangesAll) {
|
||||||
after := b.freeRanges[idx]
|
after := b.freeRangesAll[idx]
|
||||||
if newFreeRange.start+uint64(newFreeRange.size) == after.start {
|
if newFreeRange.start+uint64(newFreeRange.size) == after.start {
|
||||||
if deleteStart == -1 {
|
if deleteStart == -1 {
|
||||||
deleteStart = idx
|
deleteStart = idx
|
||||||
@@ -87,13 +93,60 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
|||||||
switch deleting {
|
switch deleting {
|
||||||
case 0:
|
case 0:
|
||||||
// if we are not deleting anything we must insert the new free range
|
// if we are not deleting anything we must insert the new free range
|
||||||
b.freeRanges = slices.Insert(b.freeRanges, idx, newFreeRange)
|
b.freeRangesAll = slices.Insert(b.freeRangesAll, idx, newFreeRange)
|
||||||
|
|
||||||
|
// if it's large add it to the list of large free ranges
|
||||||
|
if newFreeRange.isLarge() {
|
||||||
|
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||||
|
}
|
||||||
case 1:
|
case 1:
|
||||||
|
deleted := b.freeRangesAll[deleteStart]
|
||||||
|
|
||||||
// if we're deleting a single range, don't delete it, modify it in-place instead.
|
// if we're deleting a single range, don't delete it, modify it in-place instead.
|
||||||
b.freeRanges[deleteStart] = newFreeRange
|
b.freeRangesAll[deleteStart] = newFreeRange
|
||||||
|
|
||||||
|
// if the list we're modifying is in the list of large ranges modify it there too
|
||||||
|
if deleted.isLarge() {
|
||||||
|
for i, large := range b.freeRangesLarge {
|
||||||
|
if large.start == deleted.start {
|
||||||
|
b.freeRangesLarge[i] = newFreeRange
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if newFreeRange.isLarge() {
|
||||||
|
// otherwise: if after modification it's big enough we should add it to list of large ranges
|
||||||
|
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||||
|
}
|
||||||
case 2:
|
case 2:
|
||||||
// now if we're deleting two ranges, delete just one instead and modify the other in place
|
// now if we're deleting two ranges, delete the second instead and modify the first in place
|
||||||
b.freeRanges[deleteStart] = newFreeRange
|
first := b.freeRangesAll[deleteStart]
|
||||||
b.freeRanges = slices.Delete(b.freeRanges, deleteStart+1, deleteStart+1+1)
|
second := b.freeRangesAll[deleteStart+1]
|
||||||
|
|
||||||
|
b.freeRangesAll = slices.Delete(b.freeRangesAll, deleteStart+1, deleteStart+1+1)
|
||||||
|
b.freeRangesAll[deleteStart] = newFreeRange
|
||||||
|
|
||||||
|
// if the second was in the list of large lists delete it from there too
|
||||||
|
if second.isLarge() {
|
||||||
|
for i, large := range b.freeRangesLarge {
|
||||||
|
if large.start == second.start {
|
||||||
|
b.freeRangesLarge[i] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||||
|
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the list we're modifying (the first) is already in the list of large ranges modify it there too
|
||||||
|
if first.isLarge() {
|
||||||
|
for i, large := range b.freeRangesLarge {
|
||||||
|
if large.start == first.start {
|
||||||
|
b.freeRangesLarge[i] = newFreeRange
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if newFreeRange.isLarge() {
|
||||||
|
// otherwise if after modification has become big enough we should add it to list of large ranges
|
||||||
|
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func FuzzFreeRanges(f *testing.F) {
|
|||||||
|
|
||||||
total := 0
|
total := 0
|
||||||
for {
|
for {
|
||||||
freeBefore, spaceBefore := countUsableFreeRanges(mmmm)
|
freeBefore, spaceBefore := countUsableFreeRanges(t, mmmm)
|
||||||
|
|
||||||
hasAdded := false
|
hasAdded := false
|
||||||
for i := range rnd.IntN(40) {
|
for i := range rnd.IntN(40) {
|
||||||
@@ -69,7 +69,7 @@ func FuzzFreeRanges(f *testing.F) {
|
|||||||
total++
|
total++
|
||||||
}
|
}
|
||||||
|
|
||||||
freeAfter, spaceAfter := countUsableFreeRanges(mmmm)
|
freeAfter, spaceAfter := countUsableFreeRanges(t, mmmm)
|
||||||
if hasAdded && freeBefore > 0 {
|
if hasAdded && freeBefore > 0 {
|
||||||
require.Lessf(t, spaceAfter, spaceBefore, "must use some of the existing free ranges when inserting new events (before: %d, after: %d)", freeBefore, freeAfter)
|
require.Lessf(t, spaceAfter, spaceBefore, "must use some of the existing free ranges when inserting new events (before: %d, after: %d)", freeBefore, freeAfter)
|
||||||
}
|
}
|
||||||
@@ -86,9 +86,35 @@ func FuzzFreeRanges(f *testing.F) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
verifyFreeRangesInvariants(t, mmmm)
|
||||||
|
|
||||||
|
// add more events
|
||||||
|
for i := range rnd.IntN(40) {
|
||||||
|
content := "1"
|
||||||
|
if i > 0 {
|
||||||
|
content = strings.Repeat("z", rnd.IntN(1000))
|
||||||
|
}
|
||||||
|
|
||||||
|
evt := nostr.Event{
|
||||||
|
CreatedAt: nostr.Timestamp(rnd.Uint32()),
|
||||||
|
Kind: 1,
|
||||||
|
Content: content,
|
||||||
|
Tags: nostr.Tags{},
|
||||||
|
}
|
||||||
|
evt.Sign(sk)
|
||||||
|
err := il.SaveEvent(evt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
total++
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyFreeRangesInvariants(t, mmmm)
|
||||||
|
|
||||||
mmmm.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
mmmm.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
||||||
expectedFreeRanges, _ := mmmm.gatherFreeRanges(txn)
|
before := mmmm.freeRangesAll
|
||||||
require.Equalf(t, expectedFreeRanges, mmmm.freeRanges, "expected %s, got %s", expectedFreeRanges, mmmm.freeRanges)
|
err := mmmm.gatherFreeRanges(txn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equalf(t, mmmm.freeRangesAll, before, "expected %s, got %s", before, mmmm.freeRangesAll)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -99,12 +125,54 @@ func FuzzFreeRanges(f *testing.F) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func countUsableFreeRanges(mmmm *MultiMmapManager) (count int, space int) {
|
func countUsableFreeRanges(t *testing.T, mmmm *MultiMmapManager) (count int, space int) {
|
||||||
for _, fr := range mmmm.freeRanges {
|
for _, fr := range mmmm.freeRangesAll {
|
||||||
if fr.size >= 142 {
|
if fr.size >= LARGE_FREERANGE {
|
||||||
count++
|
count++
|
||||||
space += int(fr.size)
|
space += int(fr.size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
require.Equal(t, count, len(mmmm.freeRangesLarge))
|
||||||
|
|
||||||
return count, space
|
return count, space
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func verifyFreeRangesInvariants(t *testing.T, mmmm *MultiMmapManager) {
|
||||||
|
all := mmmm.freeRangesAll
|
||||||
|
large := mmmm.freeRangesLarge
|
||||||
|
|
||||||
|
for _, l := range large {
|
||||||
|
found := false
|
||||||
|
for _, a := range all {
|
||||||
|
if l.start == a.start && l.size == a.size {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, found, "large range %v not found in all ranges", l)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(all); i++ {
|
||||||
|
require.Greater(t, all[i].start, all[i-1].start, "all ranges should be sorted by start")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range all {
|
||||||
|
for j := i + 1; j < len(all); j++ {
|
||||||
|
end1 := all[i].start + uint64(all[i].size)
|
||||||
|
end2 := all[j].start + uint64(all[j].size)
|
||||||
|
require.False(t, (all[i].start >= all[j].start && all[i].start < end2) ||
|
||||||
|
(all[j].start >= all[i].start && all[j].start < end1),
|
||||||
|
"ranges %v and %v overlap", all[i], all[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mmmm.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
||||||
|
before := make(positions, len(mmmm.freeRangesAll))
|
||||||
|
copy(before, mmmm.freeRangesAll)
|
||||||
|
err := mmmm.gatherFreeRanges(txn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, before, mmmm.freeRangesAll, "recomputing free ranges should yield the same result")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (it *iterator) pull(n int, since uint32) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(it.key) != it.query.keySize || !bytes.HasPrefix(it.key, it.query.prefix) {
|
if len(it.key) != len(it.query.prefix)+4 || !bytes.HasPrefix(it.key, it.query.prefix) {
|
||||||
// we reached the end of this prefix
|
// we reached the end of this prefix
|
||||||
it.exhausted = true
|
it.exhausted = true
|
||||||
return
|
return
|
||||||
@@ -226,7 +226,7 @@ func (il *IndexingLayer) getIndexKeysForEvent(evt nostr.Event) iter.Seq[key] {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// now the p-tag+kind+date
|
// now the p-1733934977tag+kind+date
|
||||||
if dbi == il.indexTag32 && tag[0] == "p" {
|
if dbi == il.indexTag32 && tag[0] == "p" {
|
||||||
k := make([]byte, 8+2+4)
|
k := make([]byte, 8+2+4)
|
||||||
xhex.Decode(k[0:8], []byte(tag[1][0:8*2]))
|
xhex.Decode(k[0:8], []byte(tag[1][0:8*2]))
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (il *IndexingLayer) Init() error {
|
|||||||
|
|
||||||
env.SetMaxDBs(9)
|
env.SetMaxDBs(9)
|
||||||
env.SetMaxReaders(1000)
|
env.SetMaxReaders(1000)
|
||||||
env.SetMapSize(1 << 38) // ~273GB
|
env.SetMapSize(MMAP_INFINITE_SIZE)
|
||||||
|
|
||||||
// create directory if it doesn't exist and open it
|
// create directory if it doesn't exist and open it
|
||||||
if err := os.MkdirAll(path, 0755); err != nil {
|
if err := os.MkdirAll(path, 0755); err != nil {
|
||||||
|
|||||||
+46
-24
@@ -35,13 +35,15 @@ type MultiMmapManager struct {
|
|||||||
mmapfEnd uint64
|
mmapfEnd uint64
|
||||||
|
|
||||||
writeMutex sync.Mutex
|
writeMutex sync.Mutex
|
||||||
|
lockfile *os.File
|
||||||
|
|
||||||
lmdbEnv *lmdb.Env
|
lmdbEnv *lmdb.Env
|
||||||
stuff lmdb.DBI
|
stuff lmdb.DBI
|
||||||
knownLayers lmdb.DBI
|
knownLayers lmdb.DBI
|
||||||
indexId lmdb.DBI
|
indexId lmdb.DBI
|
||||||
|
|
||||||
freeRanges positions
|
freeRangesAll positions // sorted by position
|
||||||
|
freeRangesLarge []position // unsorted
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *MultiMmapManager) String() string {
|
func (b *MultiMmapManager) String() string {
|
||||||
@@ -49,33 +51,43 @@ func (b *MultiMmapManager) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MMAP_INFINITE_SIZE = 1 << 40
|
MMAP_INFINITE_SIZE = 100_000_000_000
|
||||||
maxuint16 = 65535
|
maxuint16 = 65535
|
||||||
maxuint32 = 4294967295
|
maxuint32 = 4294967295
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *MultiMmapManager) Init() error {
|
func (b *MultiMmapManager) Init() (err error) {
|
||||||
if b.Logger == nil {
|
if b.Logger == nil {
|
||||||
nopLogger := zerolog.Nop()
|
nopLogger := zerolog.Nop()
|
||||||
b.Logger = &nopLogger
|
b.Logger = &nopLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
b.releaseLock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// create directory if it doesn't exist
|
// create directory if it doesn't exist
|
||||||
dbpath := filepath.Join(b.Dir, "mmmm")
|
dbpath := filepath.Join(b.Dir, "mmmm")
|
||||||
if err := os.MkdirAll(dbpath, 0755); err != nil {
|
if err := os.MkdirAll(dbpath, 0755); err != nil {
|
||||||
return fmt.Errorf("failed to create directory %s: %w", dbpath, err)
|
return fmt.Errorf("failed to create directory %s: %w", dbpath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !b.ReadOnly {
|
// lock database directory to prevent multiple instances
|
||||||
// create lockfile to prevent multiple instances
|
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
lockfile, err := os.OpenFile(lockfilePath, os.O_CREATE|os.O_RDWR, 0644)
|
||||||
if _, err := os.OpenFile(lockfilePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644); err != nil {
|
if err != nil {
|
||||||
if os.IsExist(err) {
|
return fmt.Errorf("failed to open lockfile %s: %w", lockfilePath, err)
|
||||||
return fmt.Errorf("database at %s is already in use by another instance", b.Dir)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to create lockfile %s: %w", lockfilePath, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err := syscall.Flock(int(lockfile.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
|
||||||
|
lockfile.Close()
|
||||||
|
if errors.Is(err, syscall.EWOULDBLOCK) || errors.Is(err, syscall.EAGAIN) {
|
||||||
|
return fmt.Errorf("database at %s is already in use by another instance", b.Dir)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to lock database at %s: %w", b.Dir, err)
|
||||||
|
}
|
||||||
|
b.lockfile = lockfile
|
||||||
|
|
||||||
// open a huge mmapped file
|
// open a huge mmapped file
|
||||||
b.mmapfPath = filepath.Join(b.Dir, "events")
|
b.mmapfPath = filepath.Join(b.Dir, "events")
|
||||||
@@ -83,7 +95,7 @@ func (b *MultiMmapManager) Init() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open events file at %s: %w", b.mmapfPath, err)
|
return fmt.Errorf("failed to open events file at %s: %w", b.mmapfPath, err)
|
||||||
}
|
}
|
||||||
mmapf, err := syscall.Mmap(int(file.Fd()), 0, MMAP_INFINITE_SIZE,
|
mmapf, err := syscall.Mmap(int(file.Fd()), 0, int(MMAP_INFINITE_SIZE),
|
||||||
syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
|
syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to mmap events file at %s: %w", b.mmapfPath, err)
|
return fmt.Errorf("failed to mmap events file at %s: %w", b.mmapfPath, err)
|
||||||
@@ -104,7 +116,7 @@ func (b *MultiMmapManager) Init() error {
|
|||||||
|
|
||||||
env.SetMaxDBs(3)
|
env.SetMaxDBs(3)
|
||||||
env.SetMaxReaders(1000)
|
env.SetMaxReaders(1000)
|
||||||
env.SetMapSize(1 << 38) // ~273GB
|
env.SetMapSize(MMAP_INFINITE_SIZE)
|
||||||
|
|
||||||
err = env.Open(dbpath, lmdb.NoTLS, 0644)
|
err = env.Open(dbpath, lmdb.NoTLS, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -139,18 +151,17 @@ func (b *MultiMmapManager) Init() error {
|
|||||||
|
|
||||||
if !b.ReadOnly {
|
if !b.ReadOnly {
|
||||||
// scan index table to calculate free ranges from used positions
|
// scan index table to calculate free ranges from used positions
|
||||||
b.freeRanges, err = b.gatherFreeRanges(txn)
|
if err := b.gatherFreeRanges(txn); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logOp := b.Logger.Debug()
|
logOp := b.Logger.Debug()
|
||||||
for _, pos := range b.freeRanges {
|
count := 0
|
||||||
if pos.size > 20 {
|
for _, pos := range b.freeRangesLarge {
|
||||||
logOp = logOp.Uint32(fmt.Sprintf("%d", pos.start), pos.size)
|
logOp = logOp.Uint32(fmt.Sprintf("%d", pos.start), pos.size)
|
||||||
}
|
count++
|
||||||
}
|
}
|
||||||
logOp.Msg("calculated free ranges from index scan")
|
logOp.Int("count", count).Msg("calculated free ranges from index scan")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -365,6 +376,19 @@ func (b *MultiMmapManager) getNextAvailableLayerId(txn *lmdb.Txn) (uint16, error
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *MultiMmapManager) releaseLock() {
|
||||||
|
if b.lockfile == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = syscall.Flock(int(b.lockfile.Fd()), syscall.LOCK_UN)
|
||||||
|
_ = b.lockfile.Close()
|
||||||
|
b.lockfile = nil
|
||||||
|
|
||||||
|
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||||
|
_ = os.Remove(lockfilePath)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *MultiMmapManager) Close() {
|
func (b *MultiMmapManager) Close() {
|
||||||
b.lmdbEnv.Close()
|
b.lmdbEnv.Close()
|
||||||
for _, il := range b.layers {
|
for _, il := range b.layers {
|
||||||
@@ -373,7 +397,5 @@ func (b *MultiMmapManager) Close() {
|
|||||||
|
|
||||||
syscall.Munmap(b.mmapf)
|
syscall.Munmap(b.mmapf)
|
||||||
|
|
||||||
// remove lockfile
|
b.releaseLock()
|
||||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
|
||||||
os.Remove(lockfilePath)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,35 @@
|
|||||||
package mmm
|
package mmm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type positions []position
|
type positions []position
|
||||||
|
|
||||||
|
func (poss positions) find(start uint64) (idx int) {
|
||||||
|
idx, _ = slices.BinarySearchFunc(poss, start, func(item position, target uint64) int {
|
||||||
|
return cmp.Compare(item.start, target)
|
||||||
|
})
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (poss positions) del(start uint64) positions {
|
||||||
|
idx := poss.find(start)
|
||||||
|
return slices.Delete(poss, idx, idx+1)
|
||||||
|
}
|
||||||
|
|
||||||
func (poss positions) String() string {
|
func (poss positions) String() string {
|
||||||
str := strings.Builder{}
|
str := strings.Builder{}
|
||||||
str.Grow(10 + 20*len(poss))
|
str.Grow(10 + 20*len(poss))
|
||||||
str.WriteString("positions:[")
|
str.WriteString("positions:[")
|
||||||
for _, pos := range poss {
|
for _, pos := range poss {
|
||||||
str.WriteByte(' ')
|
|
||||||
str.WriteString(pos.String())
|
str.WriteString(pos.String())
|
||||||
}
|
}
|
||||||
str.WriteString(" ]")
|
str.WriteString("]")
|
||||||
return str.String()
|
return str.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +42,10 @@ func (pos position) String() string {
|
|||||||
return fmt.Sprintf("<%d|%d|%d>", pos.start, pos.size, pos.start+uint64(pos.size))
|
return fmt.Sprintf("<%d|%d|%d>", pos.start, pos.size, pos.start+uint64(pos.size))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pos position) isLarge() bool {
|
||||||
|
return pos.size >= LARGE_FREERANGE
|
||||||
|
}
|
||||||
|
|
||||||
func positionFromBytes(posb []byte) position {
|
func positionFromBytes(posb []byte) position {
|
||||||
return position{
|
return position{
|
||||||
size: binary.BigEndian.Uint32(posb[0:4]),
|
size: binary.BigEndian.Uint32(posb[0:4]),
|
||||||
|
|||||||
+11
-3
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore/codec/betterbinary"
|
"fiatjaf.com/nostr/eventstore/codec/betterbinary"
|
||||||
@@ -14,6 +15,12 @@ import (
|
|||||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var tempResultsPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make([]nostr.Event, 0, 64)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// GetByID returns the event -- if found in this mmm -- and all the IndexingLayers it belongs to.
|
// GetByID returns the event -- if found in this mmm -- and all the IndexingLayers it belongs to.
|
||||||
func (b *MultiMmapManager) GetByID(id nostr.ID) (*nostr.Event, IndexingLayers) {
|
func (b *MultiMmapManager) GetByID(id nostr.ID) (*nostr.Event, IndexingLayers) {
|
||||||
var event *nostr.Event
|
var event *nostr.Event
|
||||||
@@ -140,7 +147,8 @@ func (il *IndexingLayer) query(txn *lmdb.Txn, filter nostr.Filter, limit int, yi
|
|||||||
|
|
||||||
numberOfIteratorsToPullOnEachRound := max(1, int(math.Ceil(float64(len(iterators))/float64(12))))
|
numberOfIteratorsToPullOnEachRound := max(1, int(math.Ceil(float64(len(iterators))/float64(12))))
|
||||||
totalEventsEmitted := 0
|
totalEventsEmitted := 0
|
||||||
tempResults := make([]nostr.Event, 0, batchSizePerQuery*2)
|
tempResults := tempResultsPool.Get().([]nostr.Event)
|
||||||
|
defer tempResultsPool.Put(tempResults[:0])
|
||||||
|
|
||||||
for len(iterators) > 0 {
|
for len(iterators) > 0 {
|
||||||
// reset stuff
|
// reset stuff
|
||||||
@@ -180,8 +188,8 @@ func (il *IndexingLayer) query(txn *lmdb.Txn, filter nostr.Filter, limit int, yi
|
|||||||
// decode the entire thing
|
// decode the entire thing
|
||||||
event := nostr.Event{}
|
event := nostr.Event{}
|
||||||
if err := betterbinary.Unmarshal(bin, &event); err != nil {
|
if err := betterbinary.Unmarshal(bin, &event); err != nil {
|
||||||
log.Printf("lmdb: value read error (id %x) on query prefix %x sp %x dbi %v: %s\n",
|
log.Printf("mmm: value read error (id %s) on query prefix %x sp %x dbi %v: %s\n",
|
||||||
betterbinary.GetID(bin), iterators[i].query.prefix, iterators[i].query.startingPoint, iterators[i].query.dbi, err)
|
betterbinary.GetID(bin).Hex(), iterators[i].query.prefix, iterators[i].query.startingPoint, iterators[i].query.dbi, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ type query struct {
|
|||||||
i int
|
i int
|
||||||
dbi lmdb.DBI
|
dbi lmdb.DBI
|
||||||
prefix []byte
|
prefix []byte
|
||||||
keySize int
|
|
||||||
timestampSize int
|
|
||||||
startingPoint []byte
|
startingPoint []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,10 +39,10 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i, q := range queries {
|
for i, q := range queries {
|
||||||
sp := make([]byte, len(q.prefix))
|
sp := make([]byte, len(q.prefix)+4)
|
||||||
sp = sp[0:len(q.prefix)]
|
copy(sp[0:len(q.prefix)], q.prefix)
|
||||||
copy(sp, q.prefix)
|
binary.BigEndian.PutUint32(sp[len(q.prefix):], uint32(until))
|
||||||
queries[i].startingPoint = binary.BigEndian.AppendUint32(sp, uint32(until))
|
queries[i].startingPoint = sp
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -65,39 +63,27 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only "p" tag has a goodness of 2, so
|
// only "p" tag has a goodness of 2, so
|
||||||
if goodness == 2 {
|
if goodness == 2 && filter.Kinds != nil {
|
||||||
// this means we got a "p" tag, so we will use the ptag-kind index
|
// this means we got a "p" tag, so we will use the ptag-kind index
|
||||||
i := 0
|
i := 0
|
||||||
if filter.Kinds != nil {
|
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
for _, value := range tagValues {
|
||||||
for _, value := range tagValues {
|
if len(value) != 64 {
|
||||||
if len(value) != 64 {
|
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, kind := range filter.Kinds {
|
|
||||||
k := make([]byte, 8+2)
|
|
||||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
|
||||||
queries[i] = query{i: i, dbi: il.indexPTagKind, prefix: k[0 : 8+2], keySize: 8 + 2 + 4, timestampSize: 4}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// even if there are no kinds, in that case we will just return any kind and not care
|
|
||||||
queries = make([]query, len(tagValues))
|
|
||||||
for i, value := range tagValues {
|
|
||||||
if len(value) != 64 {
|
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
k := make([]byte, 8)
|
for _, kind := range filter.Kinds {
|
||||||
|
k := make([]byte, 8+2)
|
||||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||||
}
|
}
|
||||||
queries[i] = query{i: i, dbi: il.indexPTagKind, prefix: k[0:8], keySize: 8 + 2 + 4, timestampSize: 4}
|
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||||
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: il.indexPTagKind,
|
||||||
|
prefix: k[0 : 8+2],
|
||||||
|
}
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -108,7 +94,11 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
|||||||
dbi, k, offset := il.getTagIndexPrefix(tagKey, value)
|
dbi, k, offset := il.getTagIndexPrefix(tagKey, value)
|
||||||
// remove the last parts part to get just the prefix we want here
|
// remove the last parts part to get just the prefix we want here
|
||||||
prefix := k[0:offset]
|
prefix := k[0:offset]
|
||||||
queries[i] = query{i: i, dbi: dbi, prefix: prefix, keySize: len(prefix) + 4, timestampSize: 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: dbi,
|
||||||
|
prefix: prefix,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
||||||
@@ -143,9 +133,11 @@ pubkeyMatching:
|
|||||||
// will use pubkey index
|
// will use pubkey index
|
||||||
queries = make([]query, len(filter.Authors))
|
queries = make([]query, len(filter.Authors))
|
||||||
for i, pk := range filter.Authors {
|
for i, pk := range filter.Authors {
|
||||||
prefix := make([]byte, 8)
|
queries[i] = query{
|
||||||
copy(prefix[0:8], pk[0:8])
|
i: i,
|
||||||
queries[i] = query{i: i, dbi: il.indexPubkey, prefix: prefix[0:8], keySize: 8 + 4, timestampSize: 4}
|
dbi: il.indexPubkey,
|
||||||
|
prefix: pk[0:8],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// will use pubkeyKind index
|
// will use pubkeyKind index
|
||||||
@@ -156,7 +148,11 @@ pubkeyMatching:
|
|||||||
prefix := make([]byte, 8+2)
|
prefix := make([]byte, 8+2)
|
||||||
copy(prefix[0:8], pk[0:8])
|
copy(prefix[0:8], pk[0:8])
|
||||||
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
||||||
queries[i] = query{i: i, dbi: il.indexPubkeyKind, prefix: prefix[0 : 8+2], keySize: 10 + 4, timestampSize: 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: il.indexPubkeyKind,
|
||||||
|
prefix: prefix[0 : 8+2],
|
||||||
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -173,7 +169,11 @@ pubkeyMatching:
|
|||||||
for i, kind := range filter.Kinds {
|
for i, kind := range filter.Kinds {
|
||||||
prefix := make([]byte, 2)
|
prefix := make([]byte, 2)
|
||||||
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
||||||
queries[i] = query{i: i, dbi: il.indexKind, prefix: prefix[0:2], keySize: 2 + 4, timestampSize: 4}
|
queries[i] = query{
|
||||||
|
i: i,
|
||||||
|
dbi: il.indexKind,
|
||||||
|
prefix: prefix[0:2],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// potentially with an extra useless tag filtering
|
// potentially with an extra useless tag filtering
|
||||||
@@ -184,6 +184,10 @@ pubkeyMatching:
|
|||||||
// if we got here our query will have nothing to filter with
|
// if we got here our query will have nothing to filter with
|
||||||
queries = make([]query, 1)
|
queries = make([]query, 1)
|
||||||
prefix := make([]byte, 0)
|
prefix := make([]byte, 0)
|
||||||
queries[0] = query{i: 0, dbi: il.indexCreatedAt, prefix: prefix, keySize: 0 + 4, timestampSize: 4}
|
queries[0] = query{
|
||||||
|
i: 0,
|
||||||
|
dbi: il.indexCreatedAt,
|
||||||
|
prefix: prefix,
|
||||||
|
}
|
||||||
return queries, nil, nil, "", nil, since, nil
|
return queries, nil, nil, "", nil, since, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+18
-16
@@ -9,9 +9,9 @@ import (
|
|||||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||||
if il.mmmm.ReadOnly {
|
if il.mmmm.ReadOnly {
|
||||||
return ReadOnly
|
return nil, ReadOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
il.mmmm.writeMutex.Lock()
|
il.mmmm.writeMutex.Lock()
|
||||||
@@ -29,7 +29,7 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
|||||||
// prepare transactions
|
// prepare transactions
|
||||||
mmmtxn, err := il.mmmm.lmdbEnv.BeginTxn(nil, 0)
|
mmmtxn, err := il.mmmm.lmdbEnv.BeginTxn(nil, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
||||||
@@ -41,7 +41,7 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
|||||||
|
|
||||||
iltxn, err := il.lmdbEnv.BeginTxn(nil, 0)
|
iltxn, err := il.lmdbEnv.BeginTxn(nil, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
||||||
@@ -54,33 +54,35 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
|||||||
// check if we already have this id
|
// check if we already have this id
|
||||||
_, existsErr := mmmtxn.Get(il.mmmm.indexId, evt.ID[0:8])
|
_, existsErr := mmmtxn.Get(il.mmmm.indexId, evt.ID[0:8])
|
||||||
if existsErr == nil {
|
if existsErr == nil {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if !lmdb.IsNotFound(existsErr) {
|
if !lmdb.IsNotFound(existsErr) {
|
||||||
return fmt.Errorf("error checking existence: %w", existsErr)
|
return nil, fmt.Errorf("error checking existence: %w", existsErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||||
|
var qerr error
|
||||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
||||||
err = il.query(iltxn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
qerr = il.query(iltxn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if qerr != nil {
|
||||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
return nil, fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var acquiredFreeRangeFromDelete *position
|
var acquiredFreeRangeFromDelete *position
|
||||||
shouldStore := true
|
shouldStore := true
|
||||||
for previous := range results {
|
for previous := range results {
|
||||||
if nostr.IsOlder(previous, evt) {
|
if nostr.IsOlder(previous, evt) {
|
||||||
if pos, shouldPurge, err := il.delete(mmmtxn, iltxn, previous.ID); err != nil {
|
if pos, shouldPurge, derr := il.delete(mmmtxn, iltxn, previous.ID); derr != nil {
|
||||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
return nil, fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, derr)
|
||||||
} else if shouldPurge {
|
} else if shouldPurge {
|
||||||
// purge
|
// purge
|
||||||
if err := mmmtxn.Del(il.mmmm.indexId, previous.ID[0:8], nil); err != nil {
|
if err := mmmtxn.Del(il.mmmm.indexId, previous.ID[0:8], nil); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
acquiredFreeRangeFromDelete = &pos
|
acquiredFreeRangeFromDelete = &pos
|
||||||
}
|
}
|
||||||
|
deleted = append(deleted, previous)
|
||||||
} else {
|
} else {
|
||||||
// there is a newer event already stored, so we won't store this
|
// there is a newer event already stored, so we won't store this
|
||||||
shouldStore = false
|
shouldStore = false
|
||||||
@@ -90,17 +92,17 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
|||||||
if shouldStore {
|
if shouldStore {
|
||||||
_, err := il.mmmm.storeOn(mmmtxn, iltxn, il, evt)
|
_, err := il.mmmm.storeOn(mmmtxn, iltxn, il, evt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// commit in this order to minimize problematic inconsistencies
|
// commit in this order to minimize problematic inconsistencies
|
||||||
if err := mmmtxn.Commit(); err != nil {
|
if err := mmmtxn.Commit(); err != nil {
|
||||||
return fmt.Errorf("can't commit mmmtxn: %w", err)
|
return nil, fmt.Errorf("can't commit mmmtxn: %w", err)
|
||||||
}
|
}
|
||||||
mmmtxn = nil
|
mmmtxn = nil
|
||||||
if err := iltxn.Commit(); err != nil {
|
if err := iltxn.Commit(); err != nil {
|
||||||
return fmt.Errorf("can't commit iltxn: %w", err)
|
return nil, fmt.Errorf("can't commit iltxn: %w", err)
|
||||||
}
|
}
|
||||||
iltxn = nil
|
iltxn = nil
|
||||||
|
|
||||||
@@ -110,5 +112,5 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
|||||||
il.mmmm.mergeNewFreeRange(*acquiredFreeRangeFromDelete)
|
il.mmmm.mergeNewFreeRange(*acquiredFreeRangeFromDelete)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return deleted, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+22
-7
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@@ -104,25 +103,41 @@ func (b *MultiMmapManager) storeOn(
|
|||||||
return false, fmt.Errorf("event too large to store, max %d, got %d", 1<<16, pos.size)
|
return false, fmt.Errorf("event too large to store, max %d, got %d", 1<<16, pos.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find a suitable place for this to be stored in
|
// find a suitable place for this to be stored in (search only large free ranges)
|
||||||
appendToMmap := true
|
appendToMmap := true
|
||||||
for f, fr := range b.freeRanges {
|
for f, fr := range b.freeRangesLarge {
|
||||||
if fr.size >= pos.size {
|
if fr.size >= pos.size {
|
||||||
// found the smallest possible place that can fit this event
|
// found a place that can fit this event
|
||||||
appendToMmap = false
|
appendToMmap = false
|
||||||
pos.start = fr.start
|
pos.start = fr.start
|
||||||
|
|
||||||
// modify the free ranges we're keeping track of
|
// modify the free ranges we're keeping track of
|
||||||
// (in case of conflict we lose this free range but it's ok, it will be recovered on the next startup)
|
// (in case of conflict we lose this free range but it's ok, it will be recovered on the next startup)
|
||||||
if pos.size == fr.size {
|
if pos.size == fr.size {
|
||||||
// if we've used it entirely just delete it
|
// if we've used it entirely just delete it (swap-delete since it's unsorted)
|
||||||
b.freeRanges = slices.Delete(b.freeRanges, f, f+1)
|
b.freeRangesLarge[f] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||||
|
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||||
|
|
||||||
|
// also delete it from b.freeRangesAll
|
||||||
|
b.freeRangesAll = b.freeRangesAll.del(fr.start)
|
||||||
} else {
|
} else {
|
||||||
// otherwise modify it in place
|
// otherwise modify it in place
|
||||||
b.freeRanges[f] = position{
|
newFreeRange := position{
|
||||||
start: fr.start + uint64(pos.size),
|
start: fr.start + uint64(pos.size),
|
||||||
size: fr.size - pos.size,
|
size: fr.size - pos.size,
|
||||||
}
|
}
|
||||||
|
// only keep it in freeRangesLarge if it's still large enough
|
||||||
|
if newFreeRange.size >= LARGE_FREERANGE {
|
||||||
|
b.freeRangesLarge[f] = newFreeRange
|
||||||
|
} else {
|
||||||
|
// remove it from freeRangesLarge if it's no longer large enough
|
||||||
|
b.freeRangesLarge[f] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||||
|
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// also modify it in b.freeRangesAll
|
||||||
|
idx := b.freeRangesAll.find(fr.start)
|
||||||
|
b.freeRangesAll[idx] = newFreeRange
|
||||||
}
|
}
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ func (b NullStore) SaveEvent(evt nostr.Event) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b NullStore) ReplaceEvent(evt nostr.Event) error {
|
func (b NullStore) ReplaceEvent(evt nostr.Event) ([]nostr.Event, error) {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b NullStore) CountEvents(filter nostr.Filter) (uint32, error) {
|
func (b NullStore) CountEvents(filter nostr.Filter) (uint32, error) {
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (b *SliceStore) delete(id nostr.ID) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
func (b *SliceStore) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
@@ -135,8 +135,9 @@ func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
|||||||
for previous := range b.QueryEvents(filter, 1) {
|
for previous := range b.QueryEvents(filter, 1) {
|
||||||
if nostr.IsOlder(previous, evt) {
|
if nostr.IsOlder(previous, evt) {
|
||||||
if err := b.delete(previous.ID); err != nil {
|
if err := b.delete(previous.ID); err != nil {
|
||||||
return fmt.Errorf("failed to delete event for replacing: %w", err)
|
return nil, fmt.Errorf("failed to delete event for replacing: %w", err)
|
||||||
}
|
}
|
||||||
|
deleted = append(deleted, previous)
|
||||||
} else {
|
} else {
|
||||||
shouldStore = false
|
shouldStore = false
|
||||||
}
|
}
|
||||||
@@ -144,11 +145,11 @@ func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
|||||||
|
|
||||||
if shouldStore {
|
if shouldStore {
|
||||||
if err := b.save(evt); err != nil && err != eventstore.ErrDupEvent {
|
if err := b.save(evt); err != nil && err != eventstore.ErrDupEvent {
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
return nil, fmt.Errorf("failed to save: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return deleted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func eventTimestampComparator(e nostr.Event, t nostr.Timestamp) int {
|
func eventTimestampComparator(e nostr.Event, t nostr.Timestamp) int {
|
||||||
|
|||||||
+1
-1
@@ -26,7 +26,7 @@ type Store interface {
|
|||||||
|
|
||||||
// ReplaceEvent atomically replaces a replaceable or addressable event.
|
// ReplaceEvent atomically replaces a replaceable or addressable event.
|
||||||
// Conceptually it is like a Query->Delete->Save, but streamlined.
|
// Conceptually it is like a Query->Delete->Save, but streamlined.
|
||||||
ReplaceEvent(nostr.Event) error
|
ReplaceEvent(nostr.Event) (deleted []nostr.Event, err error)
|
||||||
|
|
||||||
// CountEvents counts all events that match a given filter
|
// CountEvents counts all events that match a given filter
|
||||||
CountEvents(nostr.Filter) (uint32, error)
|
CountEvents(nostr.Filter) (uint32, error)
|
||||||
|
|||||||
@@ -128,6 +128,24 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
require.Len(t, results, 1)
|
require.Len(t, results, 1)
|
||||||
require.Equal(t, events[5].ID, results[0].ID, "author + kind query error")
|
require.Equal(t, events[5].ID, results[0].ID, "author + kind query error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test 5: until
|
||||||
|
{
|
||||||
|
results := slices.Collect(db.QueryEvents(nostr.Filter{Until: 102}, 1000))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 3)
|
||||||
|
|
||||||
|
resultsWithTag := slices.Collect(db.QueryEvents(nostr.Filter{
|
||||||
|
Until: 102,
|
||||||
|
Tags: nostr.TagMap{
|
||||||
|
"e": []string{
|
||||||
|
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, 1000))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resultsWithTag, 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// from another-basic-test.patch
|
// from another-basic-test.patch
|
||||||
@@ -223,7 +241,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
}
|
}
|
||||||
originalProfile.Sign(sk3)
|
originalProfile.Sign(sk3)
|
||||||
|
|
||||||
err = db.ReplaceEvent(originalProfile)
|
_, err = db.ReplaceEvent(originalProfile)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify
|
// verify
|
||||||
@@ -244,7 +262,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
newProfile.Sign(sk3)
|
newProfile.Sign(sk3)
|
||||||
|
|
||||||
// replace with newer event
|
// replace with newer event
|
||||||
err = db.ReplaceEvent(newProfile)
|
_, err = db.ReplaceEvent(newProfile)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify only the newer event exists
|
// verify only the newer event exists
|
||||||
@@ -264,7 +282,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
}
|
}
|
||||||
olderProfile.Sign(sk3)
|
olderProfile.Sign(sk3)
|
||||||
|
|
||||||
err = db.ReplaceEvent(olderProfile)
|
_, err = db.ReplaceEvent(olderProfile)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify the newer event is still there
|
// verify the newer event is still there
|
||||||
@@ -284,7 +302,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
}
|
}
|
||||||
articleV1.Sign(sk3)
|
articleV1.Sign(sk3)
|
||||||
|
|
||||||
err = db.ReplaceEvent(articleV1)
|
_, err = db.ReplaceEvent(articleV1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify article was saved
|
// verify article was saved
|
||||||
@@ -305,7 +323,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
}
|
}
|
||||||
articleV2.Sign(sk3)
|
articleV2.Sign(sk3)
|
||||||
|
|
||||||
err = db.ReplaceEvent(articleV2)
|
_, err = db.ReplaceEvent(articleV2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify only the newer version exists
|
// verify only the newer version exists
|
||||||
@@ -327,7 +345,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
|||||||
}
|
}
|
||||||
differentArticle.Sign(sk3)
|
differentArticle.Sign(sk3)
|
||||||
|
|
||||||
err = db.ReplaceEvent(differentArticle)
|
_, err = db.ReplaceEvent(differentArticle)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// verify both articles exist (different d tags)
|
// verify both articles exist (different d tags)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ var tests = []struct {
|
|||||||
{"manyauthors", manyAuthorsTest},
|
{"manyauthors", manyAuthorsTest},
|
||||||
{"unbalanced", unbalancedTest},
|
{"unbalanced", unbalancedTest},
|
||||||
{"count", countTest},
|
{"count", countTest},
|
||||||
|
{"pfilter-until", pTagUntilMismatchTest},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSliceStore(t *testing.T) {
|
func TestSliceStore(t *testing.T) {
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"fiatjaf.com/nostr"
|
||||||
|
"fiatjaf.com/nostr/eventstore"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pTagUntilMismatchTest(t *testing.T, db eventstore.Store) {
|
||||||
|
err := db.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
targetP := "460c25e682fda7832b52d1f22d3d22b3176d972f60dcdc3212ed8c92ef85065c"
|
||||||
|
author := nostr.MustPubKeyFromHex("7fa56f5d6962ab1e3cd424e758c3002b8665f7b0d8dcee9fe9e288d7751ac194")
|
||||||
|
|
||||||
|
events := []nostr.Event{
|
||||||
|
{
|
||||||
|
Kind: 9802,
|
||||||
|
ID: nostr.MustIDFromHex("2c997233fa580b1a831f989d8fa320c409f8412d9c75b819c9a29df102d7f901"),
|
||||||
|
PubKey: author,
|
||||||
|
CreatedAt: 1773835689,
|
||||||
|
Tags: nostr.Tags{
|
||||||
|
{"e", "bcaa6599e69cff48ed6ab4b0b315d4f33a869a4ba8fa808287700faebe17195f", "wss://nos.lol/", "source"},
|
||||||
|
{"p", targetP, "", "author"},
|
||||||
|
},
|
||||||
|
Content: "With so few people donating in the zap the devs button, the incentives are quite low to produce cool new things",
|
||||||
|
Sig: sigFromHex(t, "b49206476c4d2a5f44590331541c83910fd826c0f4cdab99ceffd5bcf3aca94935e3db9d7820e7db3a0f1165c43a28dd3173a81fd08bf8348629ea4efde02537"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: 9802,
|
||||||
|
ID: nostr.MustIDFromHex("31c1eddb3a5201ef1bbce91b9fb3b7d8fe3e3eb25a66bedadcbc93c84d072c7d"),
|
||||||
|
PubKey: author,
|
||||||
|
CreatedAt: 1773154080,
|
||||||
|
Tags: nostr.Tags{
|
||||||
|
{"p", "5ea4648045bb1ff222655ddd36e6dceddc43590c26090c486bef38ef450da5bd", "", "mention"},
|
||||||
|
{"p", "c8fb0d3aa788b9ace4f6cb92dd97d3f292db25b5c9f92462ef6c64926129fbaf", "", "mention"},
|
||||||
|
{"p", "2f29aa33c2a3b45c2ef32212879248b2f4a49a002bd0de0fa16c94e138ac6f13", "", "mention"},
|
||||||
|
{"p", targetP, "", "mention"},
|
||||||
|
{"comment", "normie"},
|
||||||
|
{"e", "5911eeba39a6886fe8abea82bb50612d27d1273d63904c9b64cde070c7088d48", "wss://relay.primal.net/", "source"},
|
||||||
|
{"p", "3f770d65d3a764a9c5cb503ae123e62ec7598ad035d836e2a810f3877a745b24", "", "author"},
|
||||||
|
},
|
||||||
|
Content: "grimoire is cool, but it's too nerdy for me",
|
||||||
|
Sig: sigFromHex(t, "0ee01515d54293d52fa1247a395e64c8499df96eee80c30204cb7c8fc5b5977023e6ac00cc240f137ce7e594818545340bf74bce6d3de86539f1a5d26fe33f24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: 9802,
|
||||||
|
ID: nostr.MustIDFromHex("baeb90e2075c9d8a9b41286dbf1c52e5ef8ad6c030118839ce24d065e72df9b7"),
|
||||||
|
PubKey: author,
|
||||||
|
CreatedAt: 1773154058,
|
||||||
|
Tags: nostr.Tags{
|
||||||
|
{"p", "c8fb0d3aa788b9ace4f6cb92dd97d3f292db25b5c9f92462ef6c64926129fbaf", "", "mention"},
|
||||||
|
{"p", "2f29aa33c2a3b45c2ef32212879248b2f4a49a002bd0de0fa16c94e138ac6f13", "", "mention"},
|
||||||
|
{"p", targetP, "", "mention"},
|
||||||
|
{"p", "3f770d65d3a764a9c5cb503ae123e62ec7598ad035d836e2a810f3877a745b24", "", "mention"},
|
||||||
|
{"comment", "no lies detected"},
|
||||||
|
{"e", "81171c564cedbc5f07e5b7a9d06842d1f43a81cd79c8755921190382de55c514", "wss://nos.lol/", "source"},
|
||||||
|
{"p", "5ea4648045bb1ff222655ddd36e6dceddc43590c26090c486bef38ef450da5bd", "", "author"},
|
||||||
|
},
|
||||||
|
Content: "i never used grimoire once in my life, but that is not the point, it is still the best client",
|
||||||
|
Sig: sigFromHex(t, "3e4b855c4e3a4d2b3a078d593e728f1bbdb07af91ba5831b7866e2d16df90ce58a5f9f1db5733603911b20b334bc7fc5ae1482c7870b9f7acde4a8ccc080a79d"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, evt := range events {
|
||||||
|
err = db.SaveEvent(evt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := slices.Collect(db.QueryEvents(nostr.Filter{
|
||||||
|
Until: 1733934976,
|
||||||
|
Limit: 3,
|
||||||
|
Tags: nostr.TagMap{"p": []string{targetP}},
|
||||||
|
}, 1000))
|
||||||
|
require.Len(t, results, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sigFromHex(t *testing.T, sigStr string) [64]byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
raw, err := hex.DecodeString(sigStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, raw, 64)
|
||||||
|
|
||||||
|
var sig [64]byte
|
||||||
|
copy(sig[:], raw)
|
||||||
|
return sig
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package wrappers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
|
||||||
|
"fiatjaf.com/nostr"
|
||||||
|
"fiatjaf.com/nostr/eventstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ nostr.Publisher = DynamicPublisher{}
|
||||||
|
|
||||||
|
type DynamicPublisher struct {
|
||||||
|
GetStore func() eventstore.Store
|
||||||
|
MaxLimit int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w DynamicPublisher) QueryEvents(filter nostr.Filter) iter.Seq[nostr.Event] {
|
||||||
|
return w.GetStore().QueryEvents(filter, w.MaxLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w DynamicPublisher) Publish(ctx context.Context, evt nostr.Event) error {
|
||||||
|
if evt.Kind.IsEphemeral() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if evt.Kind.IsRegular() {
|
||||||
|
if err := w.GetStore().SaveEvent(evt); err != nil && err != eventstore.ErrDupEvent {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := w.GetStore().ReplaceEvent(evt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -39,5 +39,6 @@ func (w StorePublisher) Publish(ctx context.Context, evt nostr.Event) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// others are replaced
|
// others are replaced
|
||||||
return w.Store.ReplaceEvent(evt)
|
_, err := w.Store.ReplaceEvent(evt)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func (ef Filter) Matches(event Event) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:inline
|
||||||
func (ef Filter) MatchesIgnoringTimestampConstraints(event Event) bool {
|
func (ef Filter) MatchesIgnoringTimestampConstraints(event Event) bool {
|
||||||
if ef.IDs != nil && !slices.Contains(ef.IDs, event.ID) {
|
if ef.IDs != nil && !slices.Contains(ef.IDs, event.ID) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -40,8 +40,10 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
fiatjaf.com/lib v0.3.6
|
||||||
github.com/dgraph-io/ristretto/v2 v2.3.0
|
github.com/dgraph-io/ristretto/v2 v2.3.0
|
||||||
github.com/go-git/go-git/v5 v5.16.3
|
github.com/go-git/go-git/v5 v5.16.3
|
||||||
|
github.com/pemistahl/lingua-go v1.4.0
|
||||||
github.com/sivukhin/godjot v1.0.6
|
github.com/sivukhin/godjot v1.0.6
|
||||||
github.com/templexxx/cpu v0.0.1
|
github.com/templexxx/cpu v0.0.1
|
||||||
github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b
|
github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b
|
||||||
@@ -63,6 +65,7 @@ require (
|
|||||||
github.com/blevesearch/scorch_segment_api/v2 v2.2.16 // indirect
|
github.com/blevesearch/scorch_segment_api/v2 v2.2.16 // indirect
|
||||||
github.com/blevesearch/segment v0.9.1 // indirect
|
github.com/blevesearch/segment v0.9.1 // indirect
|
||||||
github.com/blevesearch/snowballstem v0.9.0 // indirect
|
github.com/blevesearch/snowballstem v0.9.0 // indirect
|
||||||
|
github.com/blevesearch/stempel v0.2.0 // indirect
|
||||||
github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect
|
github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect
|
||||||
github.com/blevesearch/vellum v1.0.11 // indirect
|
github.com/blevesearch/vellum v1.0.11 // indirect
|
||||||
github.com/blevesearch/zapx/v11 v11.3.10 // indirect
|
github.com/blevesearch/zapx/v11 v11.3.10 // indirect
|
||||||
@@ -93,6 +96,7 @@ require (
|
|||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect
|
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect
|
||||||
|
github.com/shopspring/decimal v1.3.1 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
fiatjaf.com/lib v0.3.6 h1:GRZNSxHI2EWdjSKVuzaT+c0aifLDtS16SzkeJaHyJfY=
|
||||||
|
fiatjaf.com/lib v0.3.6/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8=
|
||||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
|
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
|
||||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
|
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
|
||||||
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
|
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
|
||||||
@@ -40,6 +42,8 @@ github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+j
|
|||||||
github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
|
github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
|
||||||
github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s=
|
github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s=
|
||||||
github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs=
|
github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs=
|
||||||
|
github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB+Afc=
|
||||||
|
github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc=
|
||||||
github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A=
|
github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A=
|
||||||
github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
|
github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
|
||||||
github.com/blevesearch/vellum v1.0.11 h1:SJI97toEFTtA9WsDZxkyGTaBWFdWl1n2LEDCXLCq/AU=
|
github.com/blevesearch/vellum v1.0.11 h1:SJI97toEFTtA9WsDZxkyGTaBWFdWl1n2LEDCXLCq/AU=
|
||||||
@@ -190,6 +194,8 @@ github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5
|
|||||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||||
|
github.com/pemistahl/lingua-go v1.4.0 h1:ifYhthrlW7iO4icdubwlduYnmwU37V1sbNrwhKBR4rM=
|
||||||
|
github.com/pemistahl/lingua-go v1.4.0/go.mod h1:ECuM1Hp/3hvyh7k8aWSqNCPlTxLemFZsRjocUf3KgME=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
@@ -207,6 +213,8 @@ github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
|||||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
|
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
|
||||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
|
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
|
||||||
|
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||||
|
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||||
github.com/sivukhin/godjot v1.0.6 h1:yoRD+hlcDbSxP9Gd/KRVlEFXgtGyZyt0CHwhY6Gk3EQ=
|
github.com/sivukhin/godjot v1.0.6 h1:yoRD+hlcDbSxP9Gd/KRVlEFXgtGyZyt0CHwhY6Gk3EQ=
|
||||||
github.com/sivukhin/godjot v1.0.6/go.mod h1:wA6KdR4Z+XpwdwyViPDLWYYxT72pKjNc6XGA9I025gM=
|
github.com/sivukhin/godjot v1.0.6/go.mod h1:wA6KdR4Z+XpwdwyViPDLWYYxT72pKjNc6XGA9I025gM=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
|||||||
+1
-1
@@ -46,7 +46,7 @@ func (rl *Relay) handleNormal(ctx context.Context, evt nostr.Event) (skipBroadca
|
|||||||
} else {
|
} else {
|
||||||
// otherwise it's a replaceable
|
// otherwise it's a replaceable
|
||||||
if nil != rl.ReplaceEvent {
|
if nil != rl.ReplaceEvent {
|
||||||
if err := rl.ReplaceEvent(ctx, evt); err != nil {
|
if _, err := rl.ReplaceEvent(ctx, evt); err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case eventstore.ErrDupEvent:
|
case eventstore.ErrDupEvent:
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
@@ -78,6 +78,9 @@ func (rl *Relay) handleDeleteRequest(ctx context.Context, evt nostr.Event) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
haveDeletedSomething = true
|
haveDeletedSomething = true
|
||||||
|
if rl.OnEventDeleted != nil {
|
||||||
|
rl.OnEventDeleted(ctx, target)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
---
|
|
||||||
outline: deep
|
|
||||||
---
|
|
||||||
|
|
||||||
# Request Routing
|
|
||||||
|
|
||||||
If you have one (or more) set of policies that have to be executed in sequence (for example, first you check for the presence of a tag, then later in the next policies you use that tag without checking) and they only apply to some class of events, but you still want your relay to deal with other classes of events that can lead to cumbersome sets of rules, always having to check if an event meets the requirements and so on. There is where routing can help you.
|
|
||||||
|
|
||||||
```go
|
|
||||||
sk := os.Getenv("RELAY_SECRET_KEY")
|
|
||||||
|
|
||||||
// a relay for NIP-29 groups
|
|
||||||
groupsStore := boltdb.BoltBackend{}
|
|
||||||
groupsStore.Init()
|
|
||||||
groupsRelay, _ := khatru29.Init(relay29.Options{Domain: "example.com", DB: groupsStore, SecretKey: sk})
|
|
||||||
// ...
|
|
||||||
|
|
||||||
// a relay for everything else
|
|
||||||
publicStore := slicestore.SliceStore{}
|
|
||||||
publicStore.Init()
|
|
||||||
publicRelay := khatru.NewRelay()
|
|
||||||
publicRelay.UseEventStore(publicStore, 1000)
|
|
||||||
// ...
|
|
||||||
|
|
||||||
// a higher-level relay that just routes between the two above
|
|
||||||
router := khatru.NewRouter()
|
|
||||||
|
|
||||||
// route requests and events to the groups relay
|
|
||||||
router.Route().
|
|
||||||
Req(func (filter nostr.Filter) bool {
|
|
||||||
_, hasHTag := filter.Tags["h"]
|
|
||||||
if hasHTag {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return slices.Contains(filter.Kinds, func (k int) bool { return k == 39000 || k == 39001 || k == 39002 })
|
|
||||||
}).
|
|
||||||
Event(func (event *nostr.Event) bool {
|
|
||||||
switch {
|
|
||||||
case event.Kind <= 9021 && event.Kind >= 9000:
|
|
||||||
return true
|
|
||||||
case event.Kind <= 39010 && event.Kind >= 39000:
|
|
||||||
return true
|
|
||||||
case event.Kind <= 12 && event.Kind >= 9:
|
|
||||||
return true
|
|
||||||
case event.Tags.Find("h") != nil:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}).
|
|
||||||
Relay(groupsRelay)
|
|
||||||
|
|
||||||
// route requests and events to the other
|
|
||||||
router.Route().
|
|
||||||
Req(func (filter nostr.Filter) bool { return true }).
|
|
||||||
Event(func (event *nostr.Event) bool { return true }).
|
|
||||||
Relay(publicRelay)
|
|
||||||
```
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
|
||||||
"fiatjaf.com/nostr/eventstore/slicestore"
|
|
||||||
"fiatjaf.com/nostr/khatru"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
db1 := &slicestore.SliceStore{}
|
|
||||||
db1.Init()
|
|
||||||
r1 := khatru.NewRelay()
|
|
||||||
r1.UseEventstore(db1, 400)
|
|
||||||
|
|
||||||
db2 := &lmdb.LMDBBackend{Path: "/tmp/t"}
|
|
||||||
db2.Init()
|
|
||||||
r2 := khatru.NewRelay()
|
|
||||||
r2.UseEventstore(db2, 400)
|
|
||||||
|
|
||||||
db3 := &slicestore.SliceStore{}
|
|
||||||
db3.Init()
|
|
||||||
r3 := khatru.NewRelay()
|
|
||||||
r3.UseEventstore(db3, 400)
|
|
||||||
|
|
||||||
router := khatru.NewRouter()
|
|
||||||
|
|
||||||
router.Route().
|
|
||||||
Req(func(filter nostr.Filter) bool {
|
|
||||||
return slices.Contains(filter.Kinds, 30023)
|
|
||||||
}).
|
|
||||||
Event(func(event *nostr.Event) bool {
|
|
||||||
return event.Kind == 30023
|
|
||||||
}).
|
|
||||||
Relay(r1)
|
|
||||||
|
|
||||||
router.Route().
|
|
||||||
Req(func(filter nostr.Filter) bool {
|
|
||||||
return slices.Contains(filter.Kinds, 1) && slices.Contains(filter.Tags["t"], "spam")
|
|
||||||
}).
|
|
||||||
Event(func(event *nostr.Event) bool {
|
|
||||||
return event.Kind == 1 && event.Tags.FindWithValue("t", "spam") != nil
|
|
||||||
}).
|
|
||||||
Relay(r2)
|
|
||||||
|
|
||||||
router.Route().
|
|
||||||
Req(func(filter nostr.Filter) bool {
|
|
||||||
return slices.Contains(filter.Kinds, 1)
|
|
||||||
}).
|
|
||||||
Event(func(event *nostr.Event) bool {
|
|
||||||
return event.Kind == 1
|
|
||||||
}).
|
|
||||||
Relay(r3)
|
|
||||||
|
|
||||||
fmt.Println("running on :3334")
|
|
||||||
http.ListenAndServe(":3334", router)
|
|
||||||
}
|
|
||||||
+17
-5
@@ -39,9 +39,15 @@ type expirationManager struct {
|
|||||||
events expiringEventHeap
|
events expiringEventHeap
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// a function to query the relay database, generally the same as relay.queryStored
|
||||||
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event]
|
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event]
|
||||||
|
|
||||||
|
// a function to delete an event from the relay database, generally the same as relay.DeleteEvent
|
||||||
deleteEvent func(ctx context.Context, id nostr.ID) error
|
deleteEvent func(ctx context.Context, id nostr.ID) error
|
||||||
|
|
||||||
|
// a function to call after an event has been deleted, generally the same as relay.OnEventDeleted
|
||||||
|
deleteCallback func(ctx context.Context, id nostr.Event)
|
||||||
|
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
initialScanDone bool
|
initialScanDone bool
|
||||||
kill chan struct{} // used for manually killing this
|
kill chan struct{} // used for manually killing this
|
||||||
@@ -109,7 +115,11 @@ func (em *expirationManager) checkExpiredEvents(ctx context.Context) {
|
|||||||
heap.Pop(&em.events)
|
heap.Pop(&em.events)
|
||||||
|
|
||||||
ctx := context.WithValue(ctx, internalCallKey, struct{}{})
|
ctx := context.WithValue(ctx, internalCallKey, struct{}{})
|
||||||
em.deleteEvent(ctx, next.id)
|
if nil == em.deleteEvent(ctx, next.id) && em.deleteCallback != nil {
|
||||||
|
for evt := range em.queryStored(ctx, nostr.Filter{IDs: []nostr.ID{next.id}}) {
|
||||||
|
em.deleteCallback(ctx, evt)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,12 +152,14 @@ func (em *expirationManager) removeEvent(id nostr.ID) {
|
|||||||
func (rl *Relay) StartExpirationManager(
|
func (rl *Relay) StartExpirationManager(
|
||||||
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event],
|
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event],
|
||||||
deleteEvent func(ctx context.Context, id nostr.ID) error,
|
deleteEvent func(ctx context.Context, id nostr.ID) error,
|
||||||
|
onDeleteCallback func(ctx context.Context, evt nostr.Event),
|
||||||
) {
|
) {
|
||||||
rl.expirationManager = &expirationManager{
|
rl.expirationManager = &expirationManager{
|
||||||
events: make(expiringEventHeap, 0),
|
events: make(expiringEventHeap, 0),
|
||||||
|
|
||||||
queryStored: queryStored,
|
queryStored: queryStored,
|
||||||
deleteEvent: deleteEvent,
|
deleteEvent: deleteEvent,
|
||||||
|
deleteCallback: onDeleteCallback,
|
||||||
|
|
||||||
interval: time.Hour,
|
interval: time.Hour,
|
||||||
kill: make(chan struct{}),
|
kill: make(chan struct{}),
|
||||||
@@ -155,14 +167,14 @@ func (rl *Relay) StartExpirationManager(
|
|||||||
}
|
}
|
||||||
|
|
||||||
go rl.expirationManager.start(rl.ctx)
|
go rl.expirationManager.start(rl.ctx)
|
||||||
rl.Info.AddSupportedNIP(40)
|
rl.Info.AddSupportedNIP("40")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *Relay) DisableExpirationManager() {
|
func (rl *Relay) DisableExpirationManager() {
|
||||||
rl.expirationManager.stop()
|
rl.expirationManager.stop()
|
||||||
rl.expirationManager = nil
|
rl.expirationManager = nil
|
||||||
|
|
||||||
idx := slices.Index(rl.Info.SupportedNIPs, 40)
|
idx := slices.Index(rl.Info.SupportedNIPs, "40")
|
||||||
if idx != -1 {
|
if idx != -1 {
|
||||||
rl.Info.SupportedNIPs[idx] = rl.Info.SupportedNIPs[len(rl.Info.SupportedNIPs)-1]
|
rl.Info.SupportedNIPs[idx] = rl.Info.SupportedNIPs[len(rl.Info.SupportedNIPs)-1]
|
||||||
rl.Info.SupportedNIPs = rl.Info.SupportedNIPs[0 : len(rl.Info.SupportedNIPs)-1]
|
rl.Info.SupportedNIPs = rl.Info.SupportedNIPs[0 : len(rl.Info.SupportedNIPs)-1]
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
package khatru
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fasthttp/websocket"
|
|
||||||
"github.com/rs/cors"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (rl *Relay) Router() *http.ServeMux {
|
|
||||||
return rl.serveMux
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rl *Relay) SetRouter(mux *http.ServeMux) {
|
|
||||||
rl.serveMux = mux
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start creates an http server and starts listening on given host and port.
|
|
||||||
func (rl *Relay) Start(host string, port int, started ...chan bool) error {
|
|
||||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
|
||||||
ln, err := net.Listen("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rl.Addr = ln.Addr().String()
|
|
||||||
rl.httpServer = &http.Server{
|
|
||||||
Handler: cors.Default().Handler(rl),
|
|
||||||
Addr: addr,
|
|
||||||
WriteTimeout: 2 * time.Second,
|
|
||||||
ReadTimeout: 2 * time.Second,
|
|
||||||
IdleTimeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// notify caller that we're starting
|
|
||||||
for _, started := range started {
|
|
||||||
close(started)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rl.httpServer.Serve(ln); err == http.ErrServerClosed {
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown sends a websocket close control message to all connected clients.
|
|
||||||
func (rl *Relay) Shutdown(ctx context.Context) {
|
|
||||||
rl.httpServer.Shutdown(ctx)
|
|
||||||
rl.clientsMutex.Lock()
|
|
||||||
defer rl.clientsMutex.Unlock()
|
|
||||||
for ws := range rl.clients {
|
|
||||||
ws.conn.WriteControl(websocket.CloseMessage, nil, time.Now().Add(time.Second))
|
|
||||||
ws.cancel()
|
|
||||||
ws.conn.Close()
|
|
||||||
}
|
|
||||||
clear(rl.clients)
|
|
||||||
rl.listeners = rl.listeners[:0]
|
|
||||||
}
|
|
||||||
@@ -31,7 +31,7 @@ func New(rl *khatru.Relay, repositoryDir string) *GraspServer {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rl.Info.AddSupportedNIP(34)
|
rl.Info.AddSupportedNIP("34")
|
||||||
rl.Info.SupportedGrasps = append(rl.Info.SupportedGrasps, "GRASP-01")
|
rl.Info.SupportedGrasps = append(rl.Info.SupportedGrasps, "GRASP-01")
|
||||||
|
|
||||||
base := rl.Router()
|
base := rl.Router()
|
||||||
|
|||||||
+28
-40
@@ -108,17 +108,20 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
killOnce := sync.Once{}
|
||||||
kill := func() {
|
kill := func() {
|
||||||
if nil != rl.OnDisconnect {
|
killOnce.Do(func() {
|
||||||
rl.OnDisconnect(ctx)
|
if nil != rl.OnDisconnect {
|
||||||
}
|
rl.OnDisconnect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
cancel()
|
cancel()
|
||||||
ws.cancel()
|
ws.cancel()
|
||||||
ws.conn.Close()
|
ws.conn.Close()
|
||||||
|
|
||||||
rl.removeClientAndListeners(ws)
|
rl.removeClientAndListeners(ws)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -214,35 +217,30 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
srl := rl
|
|
||||||
if rl.getSubRelayFromEvent != nil {
|
|
||||||
srl = rl.getSubRelayFromEvent(&env.Event)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ok bool
|
var ok bool
|
||||||
var writeErr error
|
var writeErr error
|
||||||
var skipBroadcast bool
|
var skipBroadcast bool
|
||||||
|
|
||||||
if env.Event.Kind == nostr.KindDeletion {
|
if env.Event.Kind == nostr.KindDeletion {
|
||||||
// store the delete event first
|
// store the delete event first
|
||||||
skipBroadcast, writeErr = srl.handleNormal(ctx, env.Event)
|
skipBroadcast, writeErr = rl.handleNormal(ctx, env.Event)
|
||||||
if writeErr == nil {
|
if writeErr == nil {
|
||||||
// this always returns "blocked: " whenever it returns an error
|
// this always returns "blocked: " whenever it returns an error
|
||||||
writeErr = srl.handleDeleteRequest(ctx, env.Event)
|
writeErr = rl.handleDeleteRequest(ctx, env.Event)
|
||||||
}
|
}
|
||||||
} else if env.Event.Kind.IsEphemeral() {
|
} else if env.Event.Kind.IsEphemeral() {
|
||||||
// this will also always return a prefixed reason
|
// this will also always return a prefixed reason
|
||||||
writeErr = srl.handleEphemeral(ctx, env.Event)
|
writeErr = rl.handleEphemeral(ctx, env.Event)
|
||||||
} else {
|
} else {
|
||||||
// this will also always return a prefixed reason
|
// this will also always return a prefixed reason
|
||||||
skipBroadcast, writeErr = srl.handleNormal(ctx, env.Event)
|
skipBroadcast, writeErr = rl.handleNormal(ctx, env.Event)
|
||||||
}
|
}
|
||||||
|
|
||||||
var reason string
|
var reason string
|
||||||
if writeErr == nil {
|
if writeErr == nil {
|
||||||
ok = true
|
ok = true
|
||||||
if !skipBroadcast {
|
if !skipBroadcast {
|
||||||
n := srl.notifyListeners(env.Event, false)
|
n := rl.notifyListeners(env.Event, false)
|
||||||
|
|
||||||
// the number of notified listeners matters in ephemeral events
|
// the number of notified listeners matters in ephemeral events
|
||||||
if env.Event.Kind.IsEphemeral() {
|
if env.Event.Kind.IsEphemeral() {
|
||||||
@@ -275,15 +273,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
var total uint32
|
var total uint32
|
||||||
var hll *hyperloglog.HyperLogLog
|
var hll *hyperloglog.HyperLogLog
|
||||||
|
|
||||||
srl := rl
|
|
||||||
if rl.getSubRelayFromFilter != nil {
|
|
||||||
srl = rl.getSubRelayFromFilter(env.Filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
if offset := nip45.HyperLogLogEventPubkeyOffsetForFilter(env.Filter); offset != -1 {
|
if offset := nip45.HyperLogLogEventPubkeyOffsetForFilter(env.Filter); offset != -1 {
|
||||||
total, hll = srl.handleCountRequestWithHLL(ctx, ws, env.Filter, offset)
|
total, hll = rl.handleCountRequestWithHLL(ctx, ws, env.Filter, offset)
|
||||||
} else {
|
} else {
|
||||||
total = srl.handleCountRequest(ctx, ws, env.Filter)
|
total = rl.handleCountRequest(ctx, ws, env.Filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := nostr.CountEnvelope{
|
resp := nostr.CountEnvelope{
|
||||||
@@ -308,11 +301,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// handle each filter separately -- dispatching events as they're loaded from databases
|
// handle each filter separately -- dispatching events as they're loaded from databases
|
||||||
for _, filter := range env.Filters {
|
for _, filter := range env.Filters {
|
||||||
srl := rl
|
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
||||||
if rl.getSubRelayFromFilter != nil {
|
|
||||||
srl = rl.getSubRelayFromFilter(filter)
|
|
||||||
}
|
|
||||||
err := srl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// fail everything if any filter is rejected
|
// fail everything if any filter is rejected
|
||||||
reason := err.Error()
|
reason := err.Error()
|
||||||
@@ -322,8 +311,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
|
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
|
||||||
cancelReqCtx(errors.New("filter rejected"))
|
cancelReqCtx(errors.New("filter rejected"))
|
||||||
return
|
return
|
||||||
} else {
|
} else if filter.IDs == nil {
|
||||||
rl.addListener(ws, env.SubscriptionID, srl, filter, cancelReqCtx)
|
// a query that is just a bunch of "ids": [...] will not add listeners.
|
||||||
|
// is this a bug? maybe, but I don't think anyone is listening for an ID
|
||||||
|
// that hasn't been published yet anywhere -- if yes we can change later
|
||||||
|
rl.addListener(ws, env.SubscriptionID, filter, cancelReqCtx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,15 +352,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate: " + err.Error()})
|
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate: " + err.Error()})
|
||||||
}
|
}
|
||||||
case *nip77.OpenEnvelope:
|
case *nip77.OpenEnvelope:
|
||||||
srl := rl
|
if !rl.Negentropy {
|
||||||
if rl.getSubRelayFromFilter != nil {
|
// ignore
|
||||||
srl = rl.getSubRelayFromFilter(env.Filter)
|
return
|
||||||
if !srl.Negentropy {
|
|
||||||
// ignore
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
vec, err := srl.startNegentropySession(ctx, env.Filter)
|
vec, err := rl.startNegentropySession(ctx, env.Filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// fail everything if any filter is rejected
|
// fail everything if any filter is rejected
|
||||||
reason := err.Error()
|
reason := err.Error()
|
||||||
|
|||||||
+225
-77
@@ -3,18 +3,19 @@ package khatru
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"iter"
|
||||||
|
|
||||||
|
"fiatjaf.com/lib/set"
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrSubscriptionClosedByClient = errors.New("subscription closed by client")
|
var ErrSubscriptionClosedByClient = errors.New("subscription closed by client")
|
||||||
|
|
||||||
type listenerSpec struct {
|
type listenerSpec struct {
|
||||||
id string // kept here so we can easily match against it removeListenerId
|
ssid int // internal numeric id for a listener
|
||||||
cancel context.CancelCauseFunc
|
sid string // client-provided subscription id
|
||||||
index int
|
cancel context.CancelCauseFunc
|
||||||
subrelay *Relay // this is important when we're dealing with routing, otherwise it will be always the same
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type listener struct {
|
type listener struct {
|
||||||
@@ -23,10 +24,199 @@ type listener struct {
|
|||||||
ws *WebSocket
|
ws *WebSocket
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type subscription struct {
|
||||||
|
id string
|
||||||
|
filter nostr.Filter
|
||||||
|
ws *WebSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
type dispatcher struct {
|
||||||
|
serial int
|
||||||
|
subscriptions *xsync.MapOf[int, subscription]
|
||||||
|
byAuthor map[nostr.PubKey]set.Set[int]
|
||||||
|
byKind map[nostr.Kind]set.Set[int]
|
||||||
|
fallback set.Set[int]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDispatcher() dispatcher {
|
||||||
|
return dispatcher{
|
||||||
|
subscriptions: xsync.NewMapOf[int, subscription](),
|
||||||
|
byAuthor: make(map[nostr.PubKey]set.Set[int]),
|
||||||
|
byKind: make(map[nostr.Kind]set.Set[int]),
|
||||||
|
fallback: set.NewSliceSet[int](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dispatcher) addSubscription(sub subscription) int {
|
||||||
|
d.serial++
|
||||||
|
ssid := d.serial
|
||||||
|
|
||||||
|
d.subscriptions.Store(ssid, sub)
|
||||||
|
|
||||||
|
indexed := false
|
||||||
|
if sub.filter.Authors != nil {
|
||||||
|
indexed = true
|
||||||
|
for _, author := range sub.filter.Authors {
|
||||||
|
s, ok := d.byAuthor[author]
|
||||||
|
if !ok {
|
||||||
|
s = set.NewSliceSet[int]()
|
||||||
|
d.byAuthor[author] = s
|
||||||
|
}
|
||||||
|
s.Add(ssid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub.filter.Kinds != nil {
|
||||||
|
indexed = true
|
||||||
|
for _, kind := range sub.filter.Kinds {
|
||||||
|
s, ok := d.byKind[kind]
|
||||||
|
if !ok {
|
||||||
|
s = set.NewSliceSet[int]()
|
||||||
|
d.byKind[kind] = s
|
||||||
|
}
|
||||||
|
s.Add(ssid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !indexed {
|
||||||
|
d.fallback.Add(ssid)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dispatcher) removeSubscription(ssid int) {
|
||||||
|
sub, ok := d.subscriptions.LoadAndDelete(ssid)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
indexed := false
|
||||||
|
if sub.filter.Authors != nil {
|
||||||
|
indexed = true
|
||||||
|
for _, author := range sub.filter.Authors {
|
||||||
|
s, ok := d.byAuthor[author]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Remove(ssid)
|
||||||
|
if s.Len() == 0 {
|
||||||
|
delete(d.byAuthor, author)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub.filter.Kinds != nil {
|
||||||
|
indexed = true
|
||||||
|
for _, kind := range sub.filter.Kinds {
|
||||||
|
s, ok := d.byKind[kind]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Remove(ssid)
|
||||||
|
if s.Len() == 0 {
|
||||||
|
delete(d.byKind, kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !indexed {
|
||||||
|
d.fallback.Remove(ssid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dispatcher) candidates(event nostr.Event) iter.Seq[subscription] {
|
||||||
|
return func(yield func(subscription) bool) {
|
||||||
|
authorSubs, hasAuthorSubs := d.byAuthor[event.PubKey]
|
||||||
|
kindSubs, hasKindSubs := d.byKind[event.Kind]
|
||||||
|
|
||||||
|
if hasAuthorSubs && hasKindSubs {
|
||||||
|
for _, ssid := range authorSubs.Slice() {
|
||||||
|
sub, _ := d.subscriptions.Load(ssid)
|
||||||
|
|
||||||
|
if kindSubs.Has(ssid) {
|
||||||
|
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||||
|
if !yield(sub) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// matched author but not tags, so this event doesn't qualify for any filter
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if hasAuthorSubs {
|
||||||
|
for _, ssid := range authorSubs.Slice() {
|
||||||
|
sub, _ := d.subscriptions.Load(ssid)
|
||||||
|
|
||||||
|
if sub.filter.Kinds != nil {
|
||||||
|
// if there are any kinds in the filter we already know this doesn't qualify
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||||
|
if !yield(sub) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if hasKindSubs {
|
||||||
|
for _, ssid := range kindSubs.Slice() {
|
||||||
|
sub, _ := d.subscriptions.Load(ssid)
|
||||||
|
|
||||||
|
if sub.filter.Authors != nil {
|
||||||
|
// if there are any authors in the filter we already know this doesn't qualify
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||||
|
if !yield(sub) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ssid := range d.fallback.Slice() {
|
||||||
|
sub, _ := d.subscriptions.Load(ssid)
|
||||||
|
|
||||||
|
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||||
|
if !yield(sub) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:inline
|
||||||
|
func filterMatchesTimestampConstraintsAndTags(filter nostr.Filter, event nostr.Event) bool {
|
||||||
|
if filter.Since != 0 && event.CreatedAt < filter.Since {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Until != 0 && event.CreatedAt > filter.Until {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for f, v := range filter.Tags {
|
||||||
|
if !event.Tags.ContainsAny(f, v) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:inline
|
||||||
|
func tagKeyValueKey(tagKey, tagValue string) string {
|
||||||
|
return tagKey + "\x00" + tagValue
|
||||||
|
}
|
||||||
|
|
||||||
func (rl *Relay) GetListeningFilters() []nostr.Filter {
|
func (rl *Relay) GetListeningFilters() []nostr.Filter {
|
||||||
respfilters := make([]nostr.Filter, len(rl.listeners))
|
respfilters := make([]nostr.Filter, 0, rl.dispatcher.subscriptions.Size())
|
||||||
for i, l := range rl.listeners {
|
for _, sub := range rl.dispatcher.subscriptions.Range {
|
||||||
respfilters[i] = l.filter
|
respfilters = append(respfilters, sub.filter)
|
||||||
}
|
}
|
||||||
return respfilters
|
return respfilters
|
||||||
}
|
}
|
||||||
@@ -36,26 +226,27 @@ func (rl *Relay) GetListeningFilters() []nostr.Filter {
|
|||||||
func (rl *Relay) addListener(
|
func (rl *Relay) addListener(
|
||||||
ws *WebSocket,
|
ws *WebSocket,
|
||||||
id string,
|
id string,
|
||||||
subrelay *Relay,
|
|
||||||
filter nostr.Filter,
|
filter nostr.Filter,
|
||||||
cancel context.CancelCauseFunc,
|
cancel context.CancelCauseFunc,
|
||||||
) {
|
) {
|
||||||
rl.clientsMutex.Lock()
|
select {
|
||||||
defer rl.clientsMutex.Unlock()
|
case <-rl.clientsMutex.C():
|
||||||
|
defer rl.clientsMutex.Unlock()
|
||||||
|
case <-ws.Context.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if specs, ok := rl.clients[ws]; ok /* this will always be true unless client has disconnected very rapidly */ {
|
if specs, ok := rl.clients[ws]; ok /* this will always be true unless client has disconnected very rapidly */ {
|
||||||
idx := len(subrelay.listeners)
|
ssid := rl.dispatcher.addSubscription(subscription{
|
||||||
rl.clients[ws] = append(specs, listenerSpec{
|
|
||||||
id: id,
|
|
||||||
cancel: cancel,
|
|
||||||
subrelay: subrelay,
|
|
||||||
index: idx,
|
|
||||||
})
|
|
||||||
subrelay.listeners = append(subrelay.listeners, listener{
|
|
||||||
ws: ws,
|
ws: ws,
|
||||||
id: id,
|
id: id,
|
||||||
filter: filter,
|
filter: filter,
|
||||||
})
|
})
|
||||||
|
rl.clients[ws] = append(specs, listenerSpec{
|
||||||
|
ssid: ssid,
|
||||||
|
cancel: cancel,
|
||||||
|
sid: id,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,35 +257,16 @@ func (rl *Relay) removeListenerId(ws *WebSocket, id string) {
|
|||||||
defer rl.clientsMutex.Unlock()
|
defer rl.clientsMutex.Unlock()
|
||||||
|
|
||||||
if specs, ok := rl.clients[ws]; ok {
|
if specs, ok := rl.clients[ws]; ok {
|
||||||
// swap delete specs that match this id
|
kept := specs[:0]
|
||||||
for s := len(specs) - 1; s >= 0; s-- {
|
for _, spec := range specs {
|
||||||
spec := specs[s]
|
if spec.sid == id {
|
||||||
if spec.id == id {
|
|
||||||
spec.cancel(ErrSubscriptionClosedByClient)
|
spec.cancel(ErrSubscriptionClosedByClient)
|
||||||
specs[s] = specs[len(specs)-1]
|
rl.dispatcher.removeSubscription(spec.ssid)
|
||||||
specs = specs[0 : len(specs)-1]
|
continue
|
||||||
rl.clients[ws] = specs
|
|
||||||
|
|
||||||
// swap delete listeners one at a time, as they may be each in a different subrelay
|
|
||||||
srl := spec.subrelay // == rl in normal cases, but different when this came from a route
|
|
||||||
|
|
||||||
if spec.index != len(srl.listeners)-1 {
|
|
||||||
movedFromIndex := len(srl.listeners) - 1
|
|
||||||
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
|
|
||||||
srl.listeners[spec.index] = moved
|
|
||||||
|
|
||||||
// now we must update the the listener we just moved
|
|
||||||
// so its .index reflects its new position on srl.listeners
|
|
||||||
movedSpecs := rl.clients[moved.ws]
|
|
||||||
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
|
|
||||||
return ls.index == movedFromIndex && ls.subrelay == srl
|
|
||||||
})
|
|
||||||
movedSpecs[idx].index = spec.index
|
|
||||||
rl.clients[moved.ws] = movedSpecs
|
|
||||||
}
|
|
||||||
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
|
|
||||||
}
|
}
|
||||||
|
kept = append(kept, spec)
|
||||||
}
|
}
|
||||||
|
rl.clients[ws] = kept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,31 +274,9 @@ func (rl *Relay) removeClientAndListeners(ws *WebSocket) {
|
|||||||
rl.clientsMutex.Lock()
|
rl.clientsMutex.Lock()
|
||||||
defer rl.clientsMutex.Unlock()
|
defer rl.clientsMutex.Unlock()
|
||||||
if specs, ok := rl.clients[ws]; ok {
|
if specs, ok := rl.clients[ws]; ok {
|
||||||
// swap delete listeners and delete client (all specs will be deleted)
|
for _, spec := range specs {
|
||||||
for s, spec := range specs {
|
|
||||||
// no need to cancel contexts since they inherit from the main connection context
|
// no need to cancel contexts since they inherit from the main connection context
|
||||||
// just delete the listeners (swap-delete)
|
rl.dispatcher.removeSubscription(spec.ssid)
|
||||||
srl := spec.subrelay
|
|
||||||
|
|
||||||
if spec.index != len(srl.listeners)-1 {
|
|
||||||
movedFromIndex := len(srl.listeners) - 1
|
|
||||||
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
|
|
||||||
srl.listeners[spec.index] = moved
|
|
||||||
|
|
||||||
// temporarily update the spec of the listener being removed to have index == -1
|
|
||||||
// (since it was removed) so it doesn't match in the search below
|
|
||||||
rl.clients[ws][s].index = -1
|
|
||||||
|
|
||||||
// now we must update the the listener we just moved
|
|
||||||
// so its .index reflects its new position on srl.listeners
|
|
||||||
movedSpecs := rl.clients[moved.ws]
|
|
||||||
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
|
|
||||||
return ls.index == movedFromIndex && ls.subrelay == srl
|
|
||||||
})
|
|
||||||
movedSpecs[idx].index = spec.index
|
|
||||||
rl.clients[moved.ws] = movedSpecs
|
|
||||||
}
|
|
||||||
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(rl.clients, ws)
|
delete(rl.clients, ws)
|
||||||
@@ -136,16 +286,14 @@ func (rl *Relay) removeClientAndListeners(ws *WebSocket) {
|
|||||||
func (rl *Relay) notifyListeners(event nostr.Event, skipPrevent bool) int {
|
func (rl *Relay) notifyListeners(event nostr.Event, skipPrevent bool) int {
|
||||||
count := 0
|
count := 0
|
||||||
listenersloop:
|
listenersloop:
|
||||||
for _, listener := range rl.listeners {
|
for sub := range rl.dispatcher.candidates(event) {
|
||||||
if listener.filter.Matches(event) {
|
if !skipPrevent && nil != rl.PreventBroadcast {
|
||||||
if !skipPrevent && nil != rl.PreventBroadcast {
|
if rl.PreventBroadcast(sub.ws, sub.filter, event) {
|
||||||
if rl.PreventBroadcast(listener.ws, listener.filter, event) {
|
continue listenersloop
|
||||||
continue listenersloop
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
listener.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &listener.id, Event: event})
|
|
||||||
count++
|
|
||||||
}
|
}
|
||||||
|
sub.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &sub.id, Event: event})
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
|||||||
l := 0
|
l := 0
|
||||||
|
|
||||||
for i := 0; i < totalWebsockets; i++ {
|
for i := 0; i < totalWebsockets; i++ {
|
||||||
ws := &WebSocket{}
|
ws := &WebSocket{Context: rl.ctx}
|
||||||
websockets = append(websockets, ws)
|
websockets = append(websockets, ws)
|
||||||
rl.clients[ws] = nil
|
rl.clients[ws] = nil
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
|||||||
|
|
||||||
if s%addListenerFreq == 0 {
|
if s%addListenerFreq == 0 {
|
||||||
l++
|
l++
|
||||||
rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel)
|
rl.addListener(ws, w+":"+idFromSeqLower(j), f, cancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
s++
|
s++
|
||||||
@@ -46,14 +46,22 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, totalWebsockets)
|
require.Len(t, rl.clients, totalWebsockets)
|
||||||
require.Len(t, rl.listeners, l)
|
ssidCount := 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, l, ssidCount)
|
||||||
|
|
||||||
for ws := range rl.clients {
|
for ws := range rl.clients {
|
||||||
rl.removeClientAndListeners(ws)
|
rl.removeClientAndListeners(ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, 0)
|
require.Len(t, rl.clients, 0)
|
||||||
require.Len(t, rl.listeners, 0)
|
ssidCount = 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, 0, ssidCount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +92,7 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
|||||||
extra := 0
|
extra := 0
|
||||||
|
|
||||||
for i := 0; i < totalWebsockets; i++ {
|
for i := 0; i < totalWebsockets; i++ {
|
||||||
ws := &WebSocket{}
|
ws := &WebSocket{Context: rl.ctx}
|
||||||
websockets = append(websockets, ws)
|
websockets = append(websockets, ws)
|
||||||
rl.clients[ws] = nil
|
rl.clients[ws] = nil
|
||||||
}
|
}
|
||||||
@@ -97,11 +105,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
|||||||
|
|
||||||
if s%addListenerFreq == 0 {
|
if s%addListenerFreq == 0 {
|
||||||
id := w + ":" + idFromSeqLower(j)
|
id := w + ":" + idFromSeqLower(j)
|
||||||
rl.addListener(ws, id, rl, f, cancel)
|
rl.addListener(ws, id, f, cancel)
|
||||||
subs = append(subs, wsid{ws, id})
|
subs = append(subs, wsid{ws, id})
|
||||||
|
|
||||||
if s%addExtraListenerFreq == 0 {
|
if s%addExtraListenerFreq == 0 {
|
||||||
rl.addListener(ws, id, rl, f, cancel)
|
rl.addListener(ws, id, f, cancel)
|
||||||
extra++
|
extra++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,7 +119,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, totalWebsockets)
|
require.Len(t, rl.clients, totalWebsockets)
|
||||||
require.Len(t, rl.listeners, len(subs)+extra)
|
ssidCount := 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, len(subs)+extra, ssidCount)
|
||||||
|
|
||||||
rand.Shuffle(len(subs), func(i, j int) {
|
rand.Shuffle(len(subs), func(i, j int) {
|
||||||
subs[i], subs[j] = subs[j], subs[i]
|
subs[i], subs[j] = subs[j], subs[i]
|
||||||
@@ -120,7 +132,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
|||||||
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.listeners, 0)
|
ssidCount = 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, 0, ssidCount)
|
||||||
require.Len(t, rl.clients, totalWebsockets)
|
require.Len(t, rl.clients, totalWebsockets)
|
||||||
for _, specs := range rl.clients {
|
for _, specs := range rl.clients {
|
||||||
require.Len(t, specs, 0)
|
require.Len(t, specs, 0)
|
||||||
@@ -129,23 +145,17 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func FuzzRouterListenersPabloCrash(f *testing.F) {
|
func FuzzRouterListenersPabloCrash(f *testing.F) {
|
||||||
f.Add(uint(3), uint(6), uint(2), uint(20))
|
f.Add(uint(6), uint(2), uint(20))
|
||||||
f.Fuzz(func(t *testing.T, totalRelays uint, totalConns uint, subFreq uint, subIterations uint) {
|
f.Fuzz(func(t *testing.T, totalConns uint, subFreq uint, subIterations uint) {
|
||||||
totalRelays++
|
|
||||||
totalConns++
|
totalConns++
|
||||||
subFreq++
|
subFreq++
|
||||||
subIterations++
|
subIterations++
|
||||||
|
|
||||||
rl := NewRelay()
|
rl := NewRelay()
|
||||||
|
|
||||||
relays := make([]*Relay, int(totalRelays))
|
|
||||||
for i := 0; i < int(totalRelays); i++ {
|
|
||||||
relays[i] = NewRelay()
|
|
||||||
}
|
|
||||||
|
|
||||||
conns := make([]*WebSocket, int(totalConns))
|
conns := make([]*WebSocket, int(totalConns))
|
||||||
for i := 0; i < int(totalConns); i++ {
|
for i := 0; i < int(totalConns); i++ {
|
||||||
ws := &WebSocket{}
|
ws := &WebSocket{Context: rl.ctx}
|
||||||
conns[i] = ws
|
conns[i] = ws
|
||||||
rl.clients[ws] = make([]listenerSpec, 0, subIterations)
|
rl.clients[ws] = make([]listenerSpec, 0, subIterations)
|
||||||
}
|
}
|
||||||
@@ -159,18 +169,16 @@ func FuzzRouterListenersPabloCrash(f *testing.F) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := 0
|
s := 0
|
||||||
subs := make([]wsid, 0, subIterations*totalConns*totalRelays)
|
subs := make([]wsid, 0, subIterations*totalConns)
|
||||||
for i, conn := range conns {
|
for i, conn := range conns {
|
||||||
w := idFromSeqUpper(i)
|
w := idFromSeqUpper(i)
|
||||||
for j := 0; j < int(subIterations); j++ {
|
for j := 0; j < int(subIterations); j++ {
|
||||||
id := w + ":" + idFromSeqLower(j)
|
id := w + ":" + idFromSeqLower(j)
|
||||||
for _, rlt := range relays {
|
if s%int(subFreq) == 0 {
|
||||||
if s%int(subFreq) == 0 {
|
rl.addListener(conn, id, f, cancel)
|
||||||
rl.addListener(conn, id, rlt, f, cancel)
|
subs = append(subs, wsid{conn, id})
|
||||||
subs = append(subs, wsid{conn, id})
|
|
||||||
}
|
|
||||||
s++
|
|
||||||
}
|
}
|
||||||
|
s++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,8 +189,5 @@ func FuzzRouterListenersPabloCrash(f *testing.F) {
|
|||||||
for _, wsid := range subs {
|
for _, wsid := range subs {
|
||||||
require.Len(t, rl.clients[wsid.ws], 0)
|
require.Len(t, rl.clients[wsid.ws], 0)
|
||||||
}
|
}
|
||||||
for _, rlt := range relays {
|
|
||||||
require.Len(t, rlt.listeners, 0)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+117
-225
@@ -26,8 +26,8 @@ func idFromSeq(seq int, min, max int) string {
|
|||||||
func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
||||||
rl := NewRelay()
|
rl := NewRelay()
|
||||||
|
|
||||||
ws1 := &WebSocket{}
|
ws1 := &WebSocket{Context: rl.ctx}
|
||||||
ws2 := &WebSocket{}
|
ws2 := &WebSocket{Context: rl.ctx}
|
||||||
|
|
||||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||||
@@ -39,28 +39,21 @@ func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
|||||||
var cancel func(cause error) = nil
|
var cancel func(cause error) = nil
|
||||||
|
|
||||||
t.Run("adding listeners", func(t *testing.T) {
|
t.Run("adding listeners", func(t *testing.T) {
|
||||||
rl.addListener(ws1, "1a", rl, f1, cancel)
|
rl.addListener(ws1, "1a", f1, cancel)
|
||||||
rl.addListener(ws1, "1b", rl, f2, cancel)
|
rl.addListener(ws1, "1b", f2, cancel)
|
||||||
rl.addListener(ws2, "2a", rl, f3, cancel)
|
rl.addListener(ws2, "2a", f3, cancel)
|
||||||
rl.addListener(ws1, "1c", rl, f3, cancel)
|
rl.addListener(ws1, "1c", f3, cancel)
|
||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"1a", cancel, 0, rl},
|
{1, "1a", cancel},
|
||||||
{"1b", cancel, 1, rl},
|
{2, "1b", cancel},
|
||||||
{"1c", cancel, 3, rl},
|
{4, "1c", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"2a", cancel, 2, rl},
|
{3, "2a", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"1a", f1, ws1},
|
|
||||||
{"1b", f2, ws1},
|
|
||||||
{"2a", f3, ws2},
|
|
||||||
{"1c", f3, ws1},
|
|
||||||
}, rl.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing a client", func(t *testing.T) {
|
t.Run("removing a client", func(t *testing.T) {
|
||||||
@@ -68,23 +61,19 @@ func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws2: {
|
ws2: {
|
||||||
{"2a", cancel, 0, rl},
|
{3, "2a", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"2a", f3, ws2},
|
|
||||||
}, rl.listeners)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListenerMoreConvolutedCase(t *testing.T) {
|
func TestListenerMoreConvolutedCase(t *testing.T) {
|
||||||
rl := NewRelay()
|
rl := NewRelay()
|
||||||
|
|
||||||
ws1 := &WebSocket{}
|
ws1 := &WebSocket{Context: rl.ctx}
|
||||||
ws2 := &WebSocket{}
|
ws2 := &WebSocket{Context: rl.ctx}
|
||||||
ws3 := &WebSocket{}
|
ws3 := &WebSocket{Context: rl.ctx}
|
||||||
ws4 := &WebSocket{}
|
ws4 := &WebSocket{Context: rl.ctx}
|
||||||
|
|
||||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||||
@@ -98,35 +87,27 @@ func TestListenerMoreConvolutedCase(t *testing.T) {
|
|||||||
var cancel func(cause error) = nil
|
var cancel func(cause error) = nil
|
||||||
|
|
||||||
t.Run("adding listeners", func(t *testing.T) {
|
t.Run("adding listeners", func(t *testing.T) {
|
||||||
rl.addListener(ws1, "c", rl, f1, cancel)
|
rl.addListener(ws1, "c", f1, cancel)
|
||||||
rl.addListener(ws2, "b", rl, f2, cancel)
|
rl.addListener(ws2, "b", f2, cancel)
|
||||||
rl.addListener(ws3, "a", rl, f3, cancel)
|
rl.addListener(ws3, "a", f3, cancel)
|
||||||
rl.addListener(ws4, "d", rl, f3, cancel)
|
rl.addListener(ws4, "d", f3, cancel)
|
||||||
rl.addListener(ws2, "b", rl, f1, cancel)
|
rl.addListener(ws2, "b", f1, cancel)
|
||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rl},
|
{1, "c", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"b", cancel, 1, rl},
|
{2, "b", cancel},
|
||||||
{"b", cancel, 4, rl},
|
{5, "b", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 2, rl},
|
{3, "a", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"d", cancel, 3, rl},
|
{4, "d", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"b", f2, ws2},
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"d", f3, ws4},
|
|
||||||
{"b", f1, ws2},
|
|
||||||
}, rl.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing a client", func(t *testing.T) {
|
t.Run("removing a client", func(t *testing.T) {
|
||||||
@@ -134,85 +115,62 @@ func TestListenerMoreConvolutedCase(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rl},
|
{1, "c", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 2, rl},
|
{3, "a", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"d", cancel, 1, rl},
|
{4, "d", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"d", f3, ws4},
|
|
||||||
{"a", f3, ws3},
|
|
||||||
}, rl.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reorganize the first case differently and then remove again", func(t *testing.T) {
|
t.Run("reorganize the first case differently and then remove again", func(t *testing.T) {
|
||||||
rl.clients = map[*WebSocket][]listenerSpec{
|
rl.clients = map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 1, rl},
|
{2, "c", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"b", cancel, 2, rl},
|
{3, "b", cancel},
|
||||||
{"b", cancel, 4, rl},
|
{5, "b", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 0, rl},
|
{1, "a", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"d", cancel, 3, rl},
|
{4, "d", cancel},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rl.listeners = []listener{
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"b", f2, ws2},
|
|
||||||
{"d", f3, ws4},
|
|
||||||
{"b", f1, ws2},
|
|
||||||
}
|
|
||||||
|
|
||||||
rl.removeClientAndListeners(ws2)
|
rl.removeClientAndListeners(ws2)
|
||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 1, rl},
|
{2, "c", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 0, rl},
|
{1, "a", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"d", cancel, 2, rl},
|
{4, "d", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"d", f3, ws4},
|
|
||||||
}, rl.listeners)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||||
rl := NewRelay()
|
rl := NewRelay()
|
||||||
|
|
||||||
ws1 := &WebSocket{}
|
ws1 := &WebSocket{Context: rl.ctx}
|
||||||
ws2 := &WebSocket{}
|
ws2 := &WebSocket{Context: rl.ctx}
|
||||||
ws3 := &WebSocket{}
|
ws3 := &WebSocket{Context: rl.ctx}
|
||||||
ws4 := &WebSocket{}
|
ws4 := &WebSocket{Context: rl.ctx}
|
||||||
|
|
||||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||||
f3 := nostr.Filter{Kinds: []nostr.Kind{3}}
|
f3 := nostr.Filter{Kinds: []nostr.Kind{3}}
|
||||||
|
|
||||||
rlx := NewRelay()
|
|
||||||
rly := NewRelay()
|
|
||||||
rlz := NewRelay()
|
|
||||||
|
|
||||||
rl.clients[ws1] = nil
|
rl.clients[ws1] = nil
|
||||||
rl.clients[ws2] = nil
|
rl.clients[ws2] = nil
|
||||||
rl.clients[ws3] = nil
|
rl.clients[ws3] = nil
|
||||||
@@ -221,56 +179,37 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
|||||||
var cancel func(cause error) = nil
|
var cancel func(cause error) = nil
|
||||||
|
|
||||||
t.Run("adding listeners", func(t *testing.T) {
|
t.Run("adding listeners", func(t *testing.T) {
|
||||||
rl.addListener(ws1, "c", rlx, f1, cancel)
|
rl.addListener(ws1, "c", f1, cancel)
|
||||||
rl.addListener(ws2, "b", rly, f2, cancel)
|
rl.addListener(ws2, "b", f2, cancel)
|
||||||
rl.addListener(ws3, "a", rlz, f3, cancel)
|
rl.addListener(ws3, "a", f3, cancel)
|
||||||
rl.addListener(ws4, "d", rlx, f3, cancel)
|
rl.addListener(ws4, "d", f3, cancel)
|
||||||
rl.addListener(ws4, "e", rlx, f3, cancel)
|
rl.addListener(ws4, "e", f3, cancel)
|
||||||
rl.addListener(ws3, "a", rlx, f3, cancel)
|
rl.addListener(ws3, "a", f3, cancel)
|
||||||
rl.addListener(ws4, "e", rly, f3, cancel)
|
rl.addListener(ws4, "e", f3, cancel)
|
||||||
rl.addListener(ws3, "f", rly, f3, cancel)
|
rl.addListener(ws3, "f", f3, cancel)
|
||||||
rl.addListener(ws1, "g", rlz, f1, cancel)
|
rl.addListener(ws1, "g", f1, cancel)
|
||||||
rl.addListener(ws2, "g", rlz, f2, cancel)
|
rl.addListener(ws2, "g", f2, cancel)
|
||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rlx},
|
{1, "c", cancel},
|
||||||
{"g", cancel, 1, rlz},
|
{9, "g", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"b", cancel, 0, rly},
|
{2, "b", cancel},
|
||||||
{"g", cancel, 2, rlz},
|
{10, "g", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 0, rlz},
|
{3, "a", cancel},
|
||||||
{"a", cancel, 3, rlx},
|
{6, "a", cancel},
|
||||||
{"f", cancel, 2, rly},
|
{8, "f", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"d", cancel, 1, rlx},
|
{4, "d", cancel},
|
||||||
{"e", cancel, 2, rlx},
|
{5, "e", cancel},
|
||||||
{"e", cancel, 1, rly},
|
{7, "e", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"d", f3, ws4},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
{"a", f3, ws3},
|
|
||||||
}, rlx.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"b", f2, ws2},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
{"f", f3, ws3},
|
|
||||||
}, rly.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"g", f1, ws1},
|
|
||||||
{"g", f2, ws2},
|
|
||||||
}, rlz.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing a subscription id", func(t *testing.T) {
|
t.Run("removing a subscription id", func(t *testing.T) {
|
||||||
@@ -280,41 +219,23 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rlx},
|
{1, "c", cancel},
|
||||||
{"g", cancel, 1, rlz},
|
{9, "g", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"b", cancel, 0, rly},
|
{2, "b", cancel},
|
||||||
{"g", cancel, 2, rlz},
|
{10, "g", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"a", cancel, 0, rlz},
|
{3, "a", cancel},
|
||||||
{"a", cancel, 1, rlx},
|
{6, "a", cancel},
|
||||||
{"f", cancel, 2, rly},
|
{8, "f", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"e", cancel, 1, rly},
|
{5, "e", cancel},
|
||||||
{"e", cancel, 2, rlx},
|
{7, "e", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
}, rlx.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"b", f2, ws2},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
{"f", f3, ws3},
|
|
||||||
}, rly.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"a", f3, ws3},
|
|
||||||
{"g", f1, ws1},
|
|
||||||
{"g", f2, ws2},
|
|
||||||
}, rlz.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing another subscription id", func(t *testing.T) {
|
t.Run("removing another subscription id", func(t *testing.T) {
|
||||||
@@ -325,37 +246,21 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rlx},
|
{1, "c", cancel},
|
||||||
{"g", cancel, 1, rlz},
|
{9, "g", cancel},
|
||||||
},
|
},
|
||||||
ws2: {
|
ws2: {
|
||||||
{"b", cancel, 0, rly},
|
{2, "b", cancel},
|
||||||
{"g", cancel, 0, rlz},
|
{10, "g", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"f", cancel, 2, rly},
|
{8, "f", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"e", cancel, 1, rly},
|
{5, "e", cancel},
|
||||||
{"e", cancel, 1, rlx},
|
{7, "e", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
}, rlx.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"b", f2, ws2},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
{"f", f3, ws3},
|
|
||||||
}, rly.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"g", f2, ws2},
|
|
||||||
{"g", f1, ws1},
|
|
||||||
}, rlz.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing a connection", func(t *testing.T) {
|
t.Run("removing a connection", func(t *testing.T) {
|
||||||
@@ -363,31 +268,17 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rlx},
|
{1, "c", cancel},
|
||||||
{"g", cancel, 0, rlz},
|
{9, "g", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"f", cancel, 0, rly},
|
{8, "f", cancel},
|
||||||
},
|
},
|
||||||
ws4: {
|
ws4: {
|
||||||
{"e", cancel, 1, rly},
|
{5, "e", cancel},
|
||||||
{"e", cancel, 1, rlx},
|
{7, "e", cancel},
|
||||||
},
|
},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
}, rlx.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"f", f3, ws3},
|
|
||||||
{"e", f3, ws4},
|
|
||||||
}, rly.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"g", f1, ws1},
|
|
||||||
}, rlz.listeners)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("removing another subscription id", func(t *testing.T) {
|
t.Run("removing another subscription id", func(t *testing.T) {
|
||||||
@@ -398,26 +289,14 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||||
ws1: {
|
ws1: {
|
||||||
{"c", cancel, 0, rlx},
|
{1, "c", cancel},
|
||||||
{"g", cancel, 0, rlz},
|
{9, "g", cancel},
|
||||||
},
|
},
|
||||||
ws3: {
|
ws3: {
|
||||||
{"f", cancel, 0, rly},
|
{8, "f", cancel},
|
||||||
},
|
},
|
||||||
ws4: {},
|
ws4: {},
|
||||||
}, rl.clients)
|
}, rl.clients)
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"c", f1, ws1},
|
|
||||||
}, rlx.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"f", f3, ws3},
|
|
||||||
}, rly.listeners)
|
|
||||||
|
|
||||||
require.Equal(t, []listener{
|
|
||||||
{"g", f1, ws1},
|
|
||||||
}, rlz.listeners)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,7 +311,7 @@ func TestRandomListenerClientRemoving(t *testing.T) {
|
|||||||
l := 0
|
l := 0
|
||||||
|
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
ws := &WebSocket{}
|
ws := &WebSocket{Context: rl.ctx}
|
||||||
websockets = append(websockets, ws)
|
websockets = append(websockets, ws)
|
||||||
rl.clients[ws] = nil
|
rl.clients[ws] = nil
|
||||||
}
|
}
|
||||||
@@ -444,20 +323,28 @@ func TestRandomListenerClientRemoving(t *testing.T) {
|
|||||||
|
|
||||||
if rand.Intn(2) < 1 {
|
if rand.Intn(2) < 1 {
|
||||||
l++
|
l++
|
||||||
rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel)
|
rl.addListener(ws, w+":"+idFromSeqLower(j), f, cancel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, 20)
|
require.Len(t, rl.clients, 20)
|
||||||
require.Len(t, rl.listeners, l)
|
ssidCount := 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, l, ssidCount)
|
||||||
|
|
||||||
for ws := range rl.clients {
|
for ws := range rl.clients {
|
||||||
rl.removeClientAndListeners(ws)
|
rl.removeClientAndListeners(ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, 0)
|
require.Len(t, rl.clients, 0)
|
||||||
require.Len(t, rl.listeners, 0)
|
ssidCount = 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, 0, ssidCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomListenerIdRemoving(t *testing.T) {
|
func TestRandomListenerIdRemoving(t *testing.T) {
|
||||||
@@ -477,7 +364,7 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
|||||||
extra := 0
|
extra := 0
|
||||||
|
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
ws := &WebSocket{}
|
ws := &WebSocket{Context: rl.ctx}
|
||||||
websockets = append(websockets, ws)
|
websockets = append(websockets, ws)
|
||||||
rl.clients[ws] = nil
|
rl.clients[ws] = nil
|
||||||
}
|
}
|
||||||
@@ -489,11 +376,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
|||||||
|
|
||||||
if rand.Intn(2) < 1 {
|
if rand.Intn(2) < 1 {
|
||||||
id := w + ":" + idFromSeqLower(j)
|
id := w + ":" + idFromSeqLower(j)
|
||||||
rl.addListener(ws, id, rl, f, cancel)
|
rl.addListener(ws, id, f, cancel)
|
||||||
subs = append(subs, wsid{ws, id})
|
subs = append(subs, wsid{ws, id})
|
||||||
|
|
||||||
if rand.Intn(5) < 1 {
|
if rand.Intn(5) < 1 {
|
||||||
rl.addListener(ws, id, rl, f, cancel)
|
rl.addListener(ws, id, f, cancel)
|
||||||
extra++
|
extra++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -501,7 +388,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.clients, 20)
|
require.Len(t, rl.clients, 20)
|
||||||
require.Len(t, rl.listeners, len(subs)+extra)
|
ssidCount := 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, len(subs)+extra, ssidCount)
|
||||||
|
|
||||||
rand.Shuffle(len(subs), func(i, j int) {
|
rand.Shuffle(len(subs), func(i, j int) {
|
||||||
subs[i], subs[j] = subs[j], subs[i]
|
subs[i], subs[j] = subs[j], subs[i]
|
||||||
@@ -510,7 +401,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
|||||||
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Len(t, rl.listeners, 0)
|
ssidCount = 0
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
ssidCount += len(specs)
|
||||||
|
}
|
||||||
|
require.Equal(t, 0, ssidCount)
|
||||||
require.Len(t, rl.clients, 20)
|
require.Len(t, rl.clients, 20)
|
||||||
for _, specs := range rl.clients {
|
for _, specs := range rl.clients {
|
||||||
require.Len(t, specs, 0)
|
require.Len(t, specs, 0)
|
||||||
@@ -520,12 +415,9 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
|||||||
func TestRouterListenersPabloCrash(t *testing.T) {
|
func TestRouterListenersPabloCrash(t *testing.T) {
|
||||||
rl := NewRelay()
|
rl := NewRelay()
|
||||||
|
|
||||||
rla := NewRelay()
|
ws1 := &WebSocket{Context: rl.ctx}
|
||||||
rlb := NewRelay()
|
ws2 := &WebSocket{Context: rl.ctx}
|
||||||
|
ws3 := &WebSocket{Context: rl.ctx}
|
||||||
ws1 := &WebSocket{}
|
|
||||||
ws2 := &WebSocket{}
|
|
||||||
ws3 := &WebSocket{}
|
|
||||||
|
|
||||||
rl.clients[ws1] = nil
|
rl.clients[ws1] = nil
|
||||||
rl.clients[ws2] = nil
|
rl.clients[ws2] = nil
|
||||||
@@ -534,11 +426,11 @@ func TestRouterListenersPabloCrash(t *testing.T) {
|
|||||||
f := nostr.Filter{Kinds: []nostr.Kind{1}}
|
f := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||||
cancel := func(cause error) {}
|
cancel := func(cause error) {}
|
||||||
|
|
||||||
rl.addListener(ws1, ":1", rla, f, cancel)
|
rl.addListener(ws1, ":1", f, cancel)
|
||||||
rl.addListener(ws2, ":1", rlb, f, cancel)
|
rl.addListener(ws2, ":1", f, cancel)
|
||||||
rl.addListener(ws3, "a", rlb, f, cancel)
|
rl.addListener(ws3, "a", f, cancel)
|
||||||
rl.addListener(ws3, "b", rla, f, cancel)
|
rl.addListener(ws3, "b", f, cancel)
|
||||||
rl.addListener(ws3, "c", rlb, f, cancel)
|
rl.addListener(ws3, "c", f, cancel)
|
||||||
|
|
||||||
rl.removeClientAndListeners(ws1)
|
rl.removeClientAndListeners(ws1)
|
||||||
rl.removeClientAndListeners(ws3)
|
rl.removeClientAndListeners(ws3)
|
||||||
|
|||||||
+3
-3
@@ -12,13 +12,13 @@ func (rl *Relay) HandleNIP11(w http.ResponseWriter, r *http.Request) {
|
|||||||
info := *rl.Info
|
info := *rl.Info
|
||||||
|
|
||||||
if nil != rl.DeleteEvent {
|
if nil != rl.DeleteEvent {
|
||||||
info.AddSupportedNIP(9)
|
info.AddSupportedNIP("9")
|
||||||
}
|
}
|
||||||
if nil != rl.Count {
|
if nil != rl.Count {
|
||||||
info.AddSupportedNIP(45)
|
info.AddSupportedNIP("45")
|
||||||
}
|
}
|
||||||
if rl.Negentropy {
|
if rl.Negentropy {
|
||||||
info.AddSupportedNIP(77)
|
info.AddSupportedNIP("77")
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolve relative icon and banner URLs against base URL
|
// resolve relative icon and banner URLs against base URL
|
||||||
|
|||||||
@@ -21,8 +21,10 @@ type RelayManagementAPI struct {
|
|||||||
|
|
||||||
BanPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
BanPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||||
ListBannedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
ListBannedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
||||||
|
UnbanPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||||
AllowPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
AllowPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||||
ListAllowedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
ListAllowedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
||||||
|
UnallowPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||||
ListEventsNeedingModeration func(ctx context.Context) ([]nip86.IDReason, error)
|
ListEventsNeedingModeration func(ctx context.Context) ([]nip86.IDReason, error)
|
||||||
AllowEvent func(ctx context.Context, id nostr.ID, reason string) error
|
AllowEvent func(ctx context.Context, id nostr.ID, reason string) error
|
||||||
BanEvent func(ctx context.Context, id nostr.ID, reason string) error
|
BanEvent func(ctx context.Context, id nostr.ID, reason string) error
|
||||||
@@ -168,6 +170,14 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
|
|||||||
} else {
|
} else {
|
||||||
resp.Result = result
|
resp.Result = result
|
||||||
}
|
}
|
||||||
|
case nip86.UnbanPubKey:
|
||||||
|
if rl.ManagementAPI.UnbanPubKey == nil {
|
||||||
|
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||||
|
} else if err := rl.ManagementAPI.UnbanPubKey(ctx, thing.PubKey, thing.Reason); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
} else {
|
||||||
|
resp.Result = true
|
||||||
|
}
|
||||||
case nip86.AllowPubKey:
|
case nip86.AllowPubKey:
|
||||||
if rl.ManagementAPI.AllowPubKey == nil {
|
if rl.ManagementAPI.AllowPubKey == nil {
|
||||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||||
@@ -184,6 +194,14 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
|
|||||||
} else {
|
} else {
|
||||||
resp.Result = result
|
resp.Result = result
|
||||||
}
|
}
|
||||||
|
case nip86.UnallowPubKey:
|
||||||
|
if rl.ManagementAPI.UnallowPubKey == nil {
|
||||||
|
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||||
|
} else if err := rl.ManagementAPI.UnallowPubKey(ctx, thing.PubKey, thing.Reason); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
} else {
|
||||||
|
resp.Result = true
|
||||||
|
}
|
||||||
case nip86.BanEvent:
|
case nip86.BanEvent:
|
||||||
if rl.ManagementAPI.BanEvent == nil {
|
if rl.ManagementAPI.BanEvent == nil {
|
||||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package policies
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"iter"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -110,6 +111,9 @@ func RejectEventsWithBase64Media(ctx context.Context, evt nostr.Event) (bool, st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OnlyAllowNIP70ProtectedEvents(ctx context.Context, event nostr.Event) (reject bool, msg string) {
|
func OnlyAllowNIP70ProtectedEvents(ctx context.Context, event nostr.Event) (reject bool, msg string) {
|
||||||
|
if event.Kind == 5 {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
if nip70.IsProtected(event) {
|
if nip70.IsProtected(event) {
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
@@ -120,6 +124,9 @@ var nostrReferencesPrefix = regexp.MustCompile(`\b(nevent1|npub1|nprofile1|note1
|
|||||||
|
|
||||||
func RejectUnprefixedNostrReferences(ctx context.Context, event nostr.Event) (bool, string) {
|
func RejectUnprefixedNostrReferences(ctx context.Context, event nostr.Event) (bool, string) {
|
||||||
content := sdk.GetMainContent(event)
|
content := sdk.GetMainContent(event)
|
||||||
|
if content == "" {
|
||||||
|
content = event.Content
|
||||||
|
}
|
||||||
|
|
||||||
// only do it for stuff that wasn't parsed as blocks already
|
// only do it for stuff that wasn't parsed as blocks already
|
||||||
// (since those are already good references or URLs)
|
// (since those are already good references or URLs)
|
||||||
@@ -144,3 +151,55 @@ func RejectUnprefixedNostrReferences(ctx context.Context, event nostr.Event) (bo
|
|||||||
|
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreventNormalDuplicates prevents normal events that refer to the same thing from being saved.
|
||||||
|
// For kinds 6, 7, 16, 1018 it checks "e" tags.
|
||||||
|
// For kind 1163 it checks "p" tags.
|
||||||
|
// For kinds 1163, 6, 16, 7516, 7517 it checks "a" tags.
|
||||||
|
func PreventNormalDuplicates(query func(nostr.Filter, int) iter.Seq[nostr.Event]) func(ctx context.Context, event nostr.Event) (bool, string) {
|
||||||
|
exists := func(event nostr.Event, tagName string) bool {
|
||||||
|
hasAll := true
|
||||||
|
for t := range event.Tags.FindAll(tagName) {
|
||||||
|
hasThis := false
|
||||||
|
for range query(nostr.Filter{
|
||||||
|
Authors: []nostr.PubKey{event.PubKey},
|
||||||
|
Kinds: []nostr.Kind{event.Kind},
|
||||||
|
Tags: nostr.TagMap{tagName: []string{t[1]}},
|
||||||
|
}, 1) {
|
||||||
|
hasThis = true
|
||||||
|
}
|
||||||
|
if !hasThis {
|
||||||
|
hasAll = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hasAll
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(ctx context.Context, event nostr.Event) (bool, string) {
|
||||||
|
reject := false
|
||||||
|
|
||||||
|
switch event.Kind {
|
||||||
|
case 6:
|
||||||
|
reject = exists(event, "e") && exists(event, "a")
|
||||||
|
case 7:
|
||||||
|
reject = exists(event, "e") && exists(event, "a")
|
||||||
|
case 16:
|
||||||
|
reject = exists(event, "e") && exists(event, "a")
|
||||||
|
case 1018:
|
||||||
|
reject = exists(event, "e")
|
||||||
|
case 1163:
|
||||||
|
reject = exists(event, "p") && exists(event, "a")
|
||||||
|
case 7516:
|
||||||
|
reject = exists(event, "a")
|
||||||
|
case 7517:
|
||||||
|
reject = exists(event, "a")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reject {
|
||||||
|
return true, "an event similar to this already exists"
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+40
-14
@@ -8,9 +8,9 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"fiatjaf.com/lib/channelmutex"
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore"
|
"fiatjaf.com/nostr/eventstore"
|
||||||
"fiatjaf.com/nostr/nip11"
|
"fiatjaf.com/nostr/nip11"
|
||||||
@@ -30,7 +30,7 @@ func NewRelay() *Relay {
|
|||||||
Info: &nip11.RelayInformationDocument{
|
Info: &nip11.RelayInformationDocument{
|
||||||
Software: "https://pkg.go.dev/fiatjaf.com/nostr/khatru",
|
Software: "https://pkg.go.dev/fiatjaf.com/nostr/khatru",
|
||||||
Version: "n/a",
|
Version: "n/a",
|
||||||
SupportedNIPs: []any{1, 11, 42, 70, 86},
|
SupportedNIPs: []string{"1", "11", "42", "70", "86"},
|
||||||
},
|
},
|
||||||
|
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
@@ -39,8 +39,10 @@ func NewRelay() *Relay {
|
|||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
},
|
},
|
||||||
|
|
||||||
clients: make(map[*WebSocket][]listenerSpec, 100),
|
clients: make(map[*WebSocket][]listenerSpec, 100),
|
||||||
listeners: make([]listener, 0, 100),
|
clientsMutex: channelmutex.New(),
|
||||||
|
|
||||||
|
dispatcher: newDispatcher(),
|
||||||
|
|
||||||
serveMux: &http.ServeMux{},
|
serveMux: &http.ServeMux{},
|
||||||
|
|
||||||
@@ -66,9 +68,10 @@ type Relay struct {
|
|||||||
// hooks that will be called at various times
|
// hooks that will be called at various times
|
||||||
OnEvent func(ctx context.Context, event nostr.Event) (reject bool, msg string)
|
OnEvent func(ctx context.Context, event nostr.Event) (reject bool, msg string)
|
||||||
StoreEvent func(ctx context.Context, event nostr.Event) error
|
StoreEvent func(ctx context.Context, event nostr.Event) error
|
||||||
ReplaceEvent func(ctx context.Context, event nostr.Event) error
|
ReplaceEvent func(ctx context.Context, event nostr.Event) ([]nostr.Event, error)
|
||||||
DeleteEvent func(ctx context.Context, id nostr.ID) error
|
DeleteEvent func(ctx context.Context, id nostr.ID) error
|
||||||
OnEventSaved func(ctx context.Context, event nostr.Event)
|
OnEventSaved func(ctx context.Context, event nostr.Event)
|
||||||
|
OnEventDeleted func(ctx context.Context, deleted nostr.Event)
|
||||||
OnEphemeralEvent func(ctx context.Context, event nostr.Event)
|
OnEphemeralEvent func(ctx context.Context, event nostr.Event)
|
||||||
OnRequest func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
OnRequest func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
||||||
OnCount func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
OnCount func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
||||||
@@ -84,11 +87,6 @@ type Relay struct {
|
|||||||
// this can be ignored unless you know what you're doing
|
// this can be ignored unless you know what you're doing
|
||||||
ChallengePrefix string
|
ChallengePrefix string
|
||||||
|
|
||||||
// these are used when this relays acts as a router
|
|
||||||
routes []Route
|
|
||||||
getSubRelayFromEvent func(*nostr.Event) *Relay // used for handling EVENTs
|
|
||||||
getSubRelayFromFilter func(nostr.Filter) *Relay // used for handling REQs
|
|
||||||
|
|
||||||
// setting up handlers here will enable these methods
|
// setting up handlers here will enable these methods
|
||||||
ManagementAPI RelayManagementAPI
|
ManagementAPI RelayManagementAPI
|
||||||
|
|
||||||
@@ -105,8 +103,8 @@ type Relay struct {
|
|||||||
// keep a connection reference to all connected clients for Server.Shutdown
|
// keep a connection reference to all connected clients for Server.Shutdown
|
||||||
// also used for keeping track of who is listening to what
|
// also used for keeping track of who is listening to what
|
||||||
clients map[*WebSocket][]listenerSpec
|
clients map[*WebSocket][]listenerSpec
|
||||||
listeners []listener
|
dispatcher dispatcher
|
||||||
clientsMutex sync.Mutex
|
clientsMutex *channelmutex.Mutex
|
||||||
|
|
||||||
// set this to true to support negentropy
|
// set this to true to support negentropy
|
||||||
Negentropy bool
|
Negentropy bool
|
||||||
@@ -147,7 +145,7 @@ func (rl *Relay) UseEventstore(store eventstore.Store, maxQueryLimit int) {
|
|||||||
rl.StoreEvent = func(ctx context.Context, event nostr.Event) error {
|
rl.StoreEvent = func(ctx context.Context, event nostr.Event) error {
|
||||||
return store.SaveEvent(event)
|
return store.SaveEvent(event)
|
||||||
}
|
}
|
||||||
rl.ReplaceEvent = func(ctx context.Context, event nostr.Event) error {
|
rl.ReplaceEvent = func(ctx context.Context, event nostr.Event) ([]nostr.Event, error) {
|
||||||
return store.ReplaceEvent(event)
|
return store.ReplaceEvent(event)
|
||||||
}
|
}
|
||||||
rl.DeleteEvent = func(ctx context.Context, id nostr.ID) error {
|
rl.DeleteEvent = func(ctx context.Context, id nostr.ID) error {
|
||||||
@@ -155,7 +153,15 @@ func (rl *Relay) UseEventstore(store eventstore.Store, maxQueryLimit int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only when using the eventstore we automatically set up the expiration manager
|
// only when using the eventstore we automatically set up the expiration manager
|
||||||
rl.StartExpirationManager(rl.QueryStored, rl.DeleteEvent)
|
rl.StartExpirationManager(func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event] {
|
||||||
|
return rl.QueryStored(ctx, filter)
|
||||||
|
}, func(ctx context.Context, id nostr.ID) error {
|
||||||
|
return rl.DeleteEvent(ctx, id)
|
||||||
|
}, func(ctx context.Context, evt nostr.Event) {
|
||||||
|
if rl.OnEventDeleted != nil {
|
||||||
|
rl.OnEventDeleted(ctx, evt)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *Relay) getBaseURL(r *http.Request) string {
|
func (rl *Relay) getBaseURL(r *http.Request) string {
|
||||||
@@ -184,3 +190,23 @@ func (rl *Relay) getBaseURL(r *http.Request) string {
|
|||||||
|
|
||||||
return proto + "://" + host + r.URL.Path
|
return proto + "://" + host + r.URL.Path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stats returns the current number of connected clients and open listeners.
|
||||||
|
func (rl *Relay) Stats() (clients, listeners int) {
|
||||||
|
rl.clientsMutex.Lock()
|
||||||
|
defer rl.clientsMutex.Unlock()
|
||||||
|
|
||||||
|
for _, specs := range rl.clients {
|
||||||
|
listeners += len(specs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(rl.clients), listeners
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *Relay) Router() *http.ServeMux {
|
||||||
|
return rl.serveMux
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *Relay) SetRouter(mux *http.ServeMux) {
|
||||||
|
rl.serveMux = mux
|
||||||
|
}
|
||||||
|
|||||||
+35
-77
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
"fiatjaf.com/nostr/eventstore/slicestore"
|
"fiatjaf.com/nostr/eventstore/slicestore"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBasicRelayFunctionality(t *testing.T) {
|
func TestBasicRelayFunctionality(t *testing.T) {
|
||||||
@@ -46,15 +47,11 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
// connect two test clients
|
// connect two test clients
|
||||||
url := "ws" + server.URL[4:]
|
url := "ws" + server.URL[4:]
|
||||||
client1, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
client1, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to connect client1")
|
||||||
t.Fatalf("failed to connect client1: %v", err)
|
|
||||||
}
|
|
||||||
defer client1.Close()
|
defer client1.Close()
|
||||||
|
|
||||||
client2, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
client2, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to connect client2")
|
||||||
t.Fatalf("failed to connect client2: %v", err)
|
|
||||||
}
|
|
||||||
defer client2.Close()
|
defer client2.Close()
|
||||||
|
|
||||||
// test 1: store and query events
|
// test 1: store and query events
|
||||||
@@ -64,18 +61,14 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
|
|
||||||
evt1 := createEvent(sk1, 1, "hello world", nil)
|
evt1 := createEvent(sk1, 1, "hello world", nil)
|
||||||
err := client1.Publish(ctx, evt1)
|
err := client1.Publish(ctx, evt1)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish event")
|
||||||
t.Fatalf("failed to publish event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the event back
|
// Query the event back
|
||||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||||
Authors: []nostr.PubKey{pk1},
|
Authors: []nostr.PubKey{pk1},
|
||||||
Kinds: []nostr.Kind{1},
|
Kinds: []nostr.Kind{1},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
// Wait for event
|
// Wait for event
|
||||||
@@ -85,7 +78,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
t.Errorf("got wrong event: %v", env.ID)
|
t.Errorf("got wrong event: %v", env.ID)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for event")
|
require.FailNow(t, "timeout waiting for event")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -99,17 +92,13 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
Authors: []nostr.PubKey{pk2},
|
Authors: []nostr.PubKey{pk2},
|
||||||
Kinds: []nostr.Kind{1},
|
Kinds: []nostr.Kind{1},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
// Publish event from client2
|
// Publish event from client2
|
||||||
evt2 := createEvent(sk2, 1, "testing live events", nil)
|
evt2 := createEvent(sk2, 1, "testing live events", nil)
|
||||||
err = client2.Publish(ctx, evt2)
|
err = client2.Publish(ctx, evt2)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish event")
|
||||||
t.Fatalf("failed to publish event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for event on subscription
|
// Wait for event on subscription
|
||||||
select {
|
select {
|
||||||
@@ -118,7 +107,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
t.Errorf("got wrong event: %v", env.ID)
|
t.Errorf("got wrong event: %v", env.ID)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for live event")
|
require.FailNow(t, "timeout waiting for live event")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -130,24 +119,18 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
// Create an event to be deleted
|
// Create an event to be deleted
|
||||||
evt3 := createEvent(sk1, 1, "delete me", nil)
|
evt3 := createEvent(sk1, 1, "delete me", nil)
|
||||||
err = client1.Publish(ctx, evt3)
|
err = client1.Publish(ctx, evt3)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish event")
|
||||||
t.Fatalf("failed to publish event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create deletion event
|
// Create deletion event
|
||||||
delEvent := createEvent(sk1, 5, "deleting", nostr.Tags{{"e", evt3.ID.Hex()}})
|
delEvent := createEvent(sk1, 5, "deleting", nostr.Tags{{"e", evt3.ID.Hex()}})
|
||||||
err = client1.Publish(ctx, delEvent)
|
err = client1.Publish(ctx, delEvent)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish deletion event")
|
||||||
t.Fatalf("failed to publish deletion event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to query the deleted event
|
// Try to query the deleted event
|
||||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||||
IDs: []nostr.ID{evt3.ID},
|
IDs: []nostr.ID{evt3.ID},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
// Should get EOSE without receiving the deleted event
|
// Should get EOSE without receiving the deleted event
|
||||||
@@ -162,7 +145,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
}
|
}
|
||||||
goto checkDeleteStored
|
goto checkDeleteStored
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for EOSE")
|
require.FailNow(t, "timeout waiting for EOSE")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,9 +154,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
subDelete, err := client2.Subscribe(ctx, nostr.Filter{
|
subDelete, err := client2.Subscribe(ctx, nostr.Filter{
|
||||||
IDs: []nostr.ID{delEvent.ID},
|
IDs: []nostr.ID{delEvent.ID},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe to delete event")
|
||||||
t.Fatalf("failed to subscribe to delete event: %v", err)
|
|
||||||
}
|
|
||||||
defer subDelete.Unsub()
|
defer subDelete.Unsub()
|
||||||
|
|
||||||
gotDeleteEvent := false
|
gotDeleteEvent := false
|
||||||
@@ -189,7 +170,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for EOSE on delete event")
|
require.FailNow(t, "timeout waiting for EOSE on delete event")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -204,36 +185,28 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
evt1.CreatedAt = 1000 // Set specific timestamp for testing
|
evt1.CreatedAt = 1000 // Set specific timestamp for testing
|
||||||
evt1.Sign(sk1)
|
evt1.Sign(sk1)
|
||||||
err = client1.Publish(ctx, evt1)
|
err = client1.Publish(ctx, evt1)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish initial event")
|
||||||
t.Fatalf("failed to publish initial event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create newer event that should replace the first
|
// create newer event that should replace the first
|
||||||
evt2 := createEvent(sk1, 0, `{"name":"newer"}`, nil)
|
evt2 := createEvent(sk1, 0, `{"name":"newer"}`, nil)
|
||||||
evt2.CreatedAt = 2004 // Newer timestamp
|
evt2.CreatedAt = 2004 // Newer timestamp
|
||||||
evt2.Sign(sk1)
|
evt2.Sign(sk1)
|
||||||
err = client1.Publish(ctx, evt2)
|
err = client1.Publish(ctx, evt2)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish newer event")
|
||||||
t.Fatalf("failed to publish newer event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create older event that should not replace the current one
|
// create older event that should not replace the current one
|
||||||
evt3 := createEvent(sk1, 0, `{"name":"older"}`, nil)
|
evt3 := createEvent(sk1, 0, `{"name":"older"}`, nil)
|
||||||
evt3.CreatedAt = 1500 // Older than evt2
|
evt3.CreatedAt = 1500 // Older than evt2
|
||||||
evt3.Sign(sk1)
|
evt3.Sign(sk1)
|
||||||
err = client1.Publish(ctx, evt3)
|
err = client1.Publish(ctx, evt3)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish older event")
|
||||||
t.Fatalf("failed to publish older event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// query to verify only the newest event exists
|
// query to verify only the newest event exists
|
||||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||||
Authors: []nostr.PubKey{pk1},
|
Authors: []nostr.PubKey{pk1},
|
||||||
Kinds: []nostr.Kind{0},
|
Kinds: []nostr.Kind{0},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
// should only get one event back (the newest one)
|
// should only get one event back (the newest one)
|
||||||
@@ -251,7 +224,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for events")
|
require.FailNow(t, "timeout waiting for events")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -281,26 +254,20 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
// connect test client
|
// connect test client
|
||||||
url := "ws" + server.URL[4:]
|
url := "ws" + server.URL[4:]
|
||||||
client, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
client, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to connect client")
|
||||||
t.Fatalf("failed to connect client: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
// create event that expires in 2 seconds
|
// create event that expires in 2 seconds
|
||||||
expiration := strconv.FormatInt(int64(nostr.Now()+2), 10)
|
expiration := strconv.FormatInt(int64(nostr.Now()+2), 10)
|
||||||
evt := createEvent(sk1, 1, "i will expire soon", nostr.Tags{{"expiration", expiration}})
|
evt := createEvent(sk1, 1, "i will expire soon", nostr.Tags{{"expiration", expiration}})
|
||||||
err = client.Publish(ctx, evt)
|
err = client.Publish(ctx, evt)
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to publish event")
|
||||||
t.Fatalf("failed to publish event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify event exists initially
|
// verify event exists initially
|
||||||
sub, err := client.Subscribe(ctx, nostr.Filter{
|
sub, err := client.Subscribe(ctx, nostr.Filter{
|
||||||
IDs: []nostr.ID{evt.ID},
|
IDs: []nostr.ID{evt.ID},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// should get the event
|
// should get the event
|
||||||
select {
|
select {
|
||||||
@@ -309,7 +276,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
t.Error("got wrong event")
|
t.Error("got wrong event")
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for event")
|
require.FailNow(t, "timeout waiting for event")
|
||||||
}
|
}
|
||||||
sub.Unsub()
|
sub.Unsub()
|
||||||
|
|
||||||
@@ -320,9 +287,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
sub, err = client.Subscribe(ctx, nostr.Filter{
|
sub, err = client.Subscribe(ctx, nostr.Filter{
|
||||||
IDs: []nostr.ID{evt.ID},
|
IDs: []nostr.ID{evt.ID},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to subscribe")
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
// should get EOSE without receiving the expired event
|
// should get EOSE without receiving the expired event
|
||||||
@@ -337,7 +302,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout waiting for EOSE")
|
require.FailNow(t, "timeout waiting for EOSE")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -350,33 +315,26 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
|||||||
// create an event from client1
|
// create an event from client1
|
||||||
evt4 := createEvent(sk1, 1, "try to delete me", nil)
|
evt4 := createEvent(sk1, 1, "try to delete me", nil)
|
||||||
err = client1.Publish(ctx, evt4)
|
err = client1.Publish(ctx, evt4)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatalf("failed to publish event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to delete it with client2
|
// try to delete it with client2
|
||||||
delEvent := createEvent(sk2, 5, "trying to delete", nostr.Tags{{"e", evt4.ID.Hex()}})
|
delEvent := createEvent(sk2, 5, "trying to delete", nostr.Tags{{"e", evt4.ID.Hex()}})
|
||||||
err = client2.Publish(ctx, delEvent)
|
err = client2.Publish(ctx, delEvent)
|
||||||
if err == nil {
|
require.Error(t, err)
|
||||||
t.Fatalf("should have failed to publish deletion event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify event still exists
|
// verify event still exists
|
||||||
sub, err := client1.Subscribe(ctx, nostr.Filter{
|
sub, err := client1.Subscribe(ctx, nostr.Filter{
|
||||||
IDs: []nostr.ID{evt4.ID},
|
IDs: []nostr.ID{evt4.ID},
|
||||||
}, nostr.SubscriptionOptions{})
|
}, nostr.SubscriptionOptions{})
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatalf("failed to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case env := <-sub.Events:
|
case env, more := <-sub.Events:
|
||||||
if env.ID != evt4.ID {
|
require.True(t, more, "should get an event, got nothing")
|
||||||
t.Error("got wrong event")
|
require.Equal(t, env.ID, evt4.ID, "got wrong event")
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("event should still exist")
|
require.FailNow(t, "event should still exist")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr
|
|||||||
// run the function to query events
|
// run the function to query events
|
||||||
if nil != rl.QueryStored {
|
if nil != rl.QueryStored {
|
||||||
for event := range rl.QueryStored(ctx, filter) {
|
for event := range rl.QueryStored(ctx, filter) {
|
||||||
ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &id, Event: event})
|
if nil != ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &id, Event: event}) {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
package khatru
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fiatjaf.com/nostr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Router struct{ *Relay }
|
|
||||||
|
|
||||||
type Route struct {
|
|
||||||
eventMatcher func(*nostr.Event) bool
|
|
||||||
filterMatcher func(nostr.Filter) bool
|
|
||||||
relay *Relay
|
|
||||||
}
|
|
||||||
|
|
||||||
type routeBuilder struct {
|
|
||||||
router *Router
|
|
||||||
eventMatcher func(*nostr.Event) bool
|
|
||||||
filterMatcher func(nostr.Filter) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRouter() *Router {
|
|
||||||
rr := &Router{Relay: NewRelay()}
|
|
||||||
rr.routes = make([]Route, 0, 3)
|
|
||||||
rr.getSubRelayFromFilter = func(f nostr.Filter) *Relay {
|
|
||||||
for _, route := range rr.routes {
|
|
||||||
if route.filterMatcher == nil || route.filterMatcher(f) {
|
|
||||||
return route.relay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return rr.Relay
|
|
||||||
}
|
|
||||||
rr.getSubRelayFromEvent = func(e *nostr.Event) *Relay {
|
|
||||||
for _, route := range rr.routes {
|
|
||||||
if route.eventMatcher == nil || route.eventMatcher(e) {
|
|
||||||
return route.relay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return rr.Relay
|
|
||||||
}
|
|
||||||
return rr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rr *Router) Route() routeBuilder {
|
|
||||||
return routeBuilder{
|
|
||||||
router: rr,
|
|
||||||
filterMatcher: func(f nostr.Filter) bool { return false },
|
|
||||||
eventMatcher: func(e *nostr.Event) bool { return false },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb routeBuilder) Req(fn func(nostr.Filter) bool) routeBuilder {
|
|
||||||
rb.filterMatcher = fn
|
|
||||||
return rb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb routeBuilder) AnyReq() routeBuilder {
|
|
||||||
rb.filterMatcher = nil
|
|
||||||
return rb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb routeBuilder) Event(fn func(*nostr.Event) bool) routeBuilder {
|
|
||||||
rb.eventMatcher = fn
|
|
||||||
return rb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb routeBuilder) AnyEvent() routeBuilder {
|
|
||||||
rb.eventMatcher = nil
|
|
||||||
return rb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb routeBuilder) Relay(relay *Relay) {
|
|
||||||
rb.router.routes = append(rb.router.routes, Route{
|
|
||||||
filterMatcher: rb.filterMatcher,
|
|
||||||
eventMatcher: rb.eventMatcher,
|
|
||||||
relay: relay,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -2,6 +2,7 @@ package khatru
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -31,6 +32,9 @@ type WebSocket struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebSocket) WriteJSON(any any) error {
|
func (ws *WebSocket) WriteJSON(any any) error {
|
||||||
|
if ws == nil {
|
||||||
|
return fmt.Errorf("connection doesn't exist")
|
||||||
|
}
|
||||||
ws.mutex.Lock()
|
ws.mutex.Lock()
|
||||||
err := ws.conn.WriteJSON(any)
|
err := ws.conn.WriteJSON(any)
|
||||||
ws.mutex.Unlock()
|
ws.mutex.Unlock()
|
||||||
|
|||||||
@@ -246,6 +246,8 @@ func (kind Kind) Name() string {
|
|||||||
return "SimpleGroupMembers"
|
return "SimpleGroupMembers"
|
||||||
case KindSimpleGroupRoles:
|
case KindSimpleGroupRoles:
|
||||||
return "SimpleGroupRoles"
|
return "SimpleGroupRoles"
|
||||||
|
case KindSimpleGroupLiveKitParticipants:
|
||||||
|
return "SimpleGroupLiveKitParticipants"
|
||||||
case KindWikiArticle:
|
case KindWikiArticle:
|
||||||
return "WikiArticle"
|
return "WikiArticle"
|
||||||
case KindRedirects:
|
case KindRedirects:
|
||||||
@@ -277,138 +279,139 @@ func (kind Kind) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
KindProfileMetadata Kind = 0
|
KindProfileMetadata Kind = 0
|
||||||
KindTextNote Kind = 1
|
KindTextNote Kind = 1
|
||||||
KindRecommendServer Kind = 2
|
KindRecommendServer Kind = 2
|
||||||
KindFollowList Kind = 3
|
KindFollowList Kind = 3
|
||||||
KindEncryptedDirectMessage Kind = 4
|
KindEncryptedDirectMessage Kind = 4
|
||||||
KindDeletion Kind = 5
|
KindDeletion Kind = 5
|
||||||
KindRepost Kind = 6
|
KindRepost Kind = 6
|
||||||
KindReaction Kind = 7
|
KindReaction Kind = 7
|
||||||
KindBadgeAward Kind = 8
|
KindBadgeAward Kind = 8
|
||||||
KindSimpleGroupChatMessage Kind = 9
|
KindSimpleGroupChatMessage Kind = 9
|
||||||
KindSimpleGroupThreadedReply Kind = 10
|
KindSimpleGroupThreadedReply Kind = 10
|
||||||
KindSimpleGroupThread Kind = 11
|
KindSimpleGroupThread Kind = 11
|
||||||
KindSimpleGroupReply Kind = 12
|
KindSimpleGroupReply Kind = 12
|
||||||
KindSeal Kind = 13
|
KindSeal Kind = 13
|
||||||
KindDirectMessage Kind = 14
|
KindDirectMessage Kind = 14
|
||||||
KindGenericRepost Kind = 16
|
KindGenericRepost Kind = 16
|
||||||
KindReactionToWebsite Kind = 17
|
KindReactionToWebsite Kind = 17
|
||||||
KindChannelCreation Kind = 40
|
KindChannelCreation Kind = 40
|
||||||
KindChannelMetadata Kind = 41
|
KindChannelMetadata Kind = 41
|
||||||
KindChannelMessage Kind = 42
|
KindChannelMessage Kind = 42
|
||||||
KindChannelHideMessage Kind = 43
|
KindChannelHideMessage Kind = 43
|
||||||
KindChannelMuteUser Kind = 44
|
KindChannelMuteUser Kind = 44
|
||||||
KindChess Kind = 64
|
KindChess Kind = 64
|
||||||
KindMergeRequests Kind = 818
|
KindMergeRequests Kind = 818
|
||||||
KindComment Kind = 1111
|
KindComment Kind = 1111
|
||||||
KindBid Kind = 1021
|
KindBid Kind = 1021
|
||||||
KindBidConfirmation Kind = 1022
|
KindBidConfirmation Kind = 1022
|
||||||
KindOpenTimestamps Kind = 1040
|
KindOpenTimestamps Kind = 1040
|
||||||
KindGiftWrap Kind = 1059
|
KindGiftWrap Kind = 1059
|
||||||
KindFileMetadata Kind = 1063
|
KindFileMetadata Kind = 1063
|
||||||
KindLiveChatMessage Kind = 1311
|
KindLiveChatMessage Kind = 1311
|
||||||
KindPatch Kind = 1617
|
KindPatch Kind = 1617
|
||||||
KindIssue Kind = 1621
|
KindIssue Kind = 1621
|
||||||
KindReply Kind = 1622
|
KindReply Kind = 1622
|
||||||
KindStatusOpen Kind = 1630
|
KindStatusOpen Kind = 1630
|
||||||
KindStatusApplied Kind = 1631
|
KindStatusApplied Kind = 1631
|
||||||
KindStatusClosed Kind = 1632
|
KindStatusClosed Kind = 1632
|
||||||
KindStatusDraft Kind = 1633
|
KindStatusDraft Kind = 1633
|
||||||
KindProblemTracker Kind = 1971
|
KindProblemTracker Kind = 1971
|
||||||
KindReporting Kind = 1984
|
KindReporting Kind = 1984
|
||||||
KindLabel Kind = 1985
|
KindLabel Kind = 1985
|
||||||
KindRelayReviews Kind = 1986
|
KindRelayReviews Kind = 1986
|
||||||
KindAIEmbeddings Kind = 1987
|
KindAIEmbeddings Kind = 1987
|
||||||
KindTorrent Kind = 2003
|
KindTorrent Kind = 2003
|
||||||
KindTorrentComment Kind = 2004
|
KindTorrentComment Kind = 2004
|
||||||
KindCoinjoinPool Kind = 2022
|
KindCoinjoinPool Kind = 2022
|
||||||
KindCommunityPostApproval Kind = 4550
|
KindCommunityPostApproval Kind = 4550
|
||||||
KindJobFeedback Kind = 7000
|
KindJobFeedback Kind = 7000
|
||||||
KindSimpleGroupPutUser Kind = 9000
|
KindSimpleGroupPutUser Kind = 9000
|
||||||
KindSimpleGroupRemoveUser Kind = 9001
|
KindSimpleGroupRemoveUser Kind = 9001
|
||||||
KindSimpleGroupEditMetadata Kind = 9002
|
KindSimpleGroupEditMetadata Kind = 9002
|
||||||
KindSimpleGroupDeleteEvent Kind = 9005
|
KindSimpleGroupDeleteEvent Kind = 9005
|
||||||
KindSimpleGroupCreateGroup Kind = 9007
|
KindSimpleGroupCreateGroup Kind = 9007
|
||||||
KindSimpleGroupDeleteGroup Kind = 9008
|
KindSimpleGroupDeleteGroup Kind = 9008
|
||||||
KindSimpleGroupCreateInvite Kind = 9009
|
KindSimpleGroupCreateInvite Kind = 9009
|
||||||
KindSimpleGroupJoinRequest Kind = 9021
|
KindSimpleGroupJoinRequest Kind = 9021
|
||||||
KindSimpleGroupLeaveRequest Kind = 9022
|
KindSimpleGroupLeaveRequest Kind = 9022
|
||||||
KindZapGoal Kind = 9041
|
KindZapGoal Kind = 9041
|
||||||
KindNutZap Kind = 9321
|
KindNutZap Kind = 9321
|
||||||
KindTidalLogin Kind = 9467
|
KindTidalLogin Kind = 9467
|
||||||
KindZapRequest Kind = 9734
|
KindZapRequest Kind = 9734
|
||||||
KindZap Kind = 9735
|
KindZap Kind = 9735
|
||||||
KindHighlights Kind = 9802
|
KindHighlights Kind = 9802
|
||||||
KindMuteList Kind = 10000
|
KindMuteList Kind = 10000
|
||||||
KindPinList Kind = 10001
|
KindPinList Kind = 10001
|
||||||
KindRelayListMetadata Kind = 10002
|
KindRelayListMetadata Kind = 10002
|
||||||
KindBookmarkList Kind = 10003
|
KindBookmarkList Kind = 10003
|
||||||
KindCommunityList Kind = 10004
|
KindCommunityList Kind = 10004
|
||||||
KindPublicChatList Kind = 10005
|
KindPublicChatList Kind = 10005
|
||||||
KindBlockedRelayList Kind = 10006
|
KindBlockedRelayList Kind = 10006
|
||||||
KindSearchRelayList Kind = 10007
|
KindSearchRelayList Kind = 10007
|
||||||
KindSimpleGroupList Kind = 10009
|
KindSimpleGroupList Kind = 10009
|
||||||
KindInterestList Kind = 10015
|
KindInterestList Kind = 10015
|
||||||
KindNutZapInfo Kind = 10019
|
KindNutZapInfo Kind = 10019
|
||||||
KindEmojiList Kind = 10030
|
KindEmojiList Kind = 10030
|
||||||
KindDMRelayList Kind = 10050
|
KindDMRelayList Kind = 10050
|
||||||
KindUserServerList Kind = 10063
|
KindUserServerList Kind = 10063
|
||||||
KindFileStorageServerList Kind = 10096
|
KindFileStorageServerList Kind = 10096
|
||||||
KindGoodWikiAuthorList Kind = 10101
|
KindGoodWikiAuthorList Kind = 10101
|
||||||
KindGoodWikiRelayList Kind = 10102
|
KindGoodWikiRelayList Kind = 10102
|
||||||
KindNWCWalletInfo Kind = 13194
|
KindNWCWalletInfo Kind = 13194
|
||||||
KindLightningPubRPC Kind = 21000
|
KindLightningPubRPC Kind = 21000
|
||||||
KindClientAuthentication Kind = 22242
|
KindClientAuthentication Kind = 22242
|
||||||
KindNWCWalletRequest Kind = 23194
|
KindNWCWalletRequest Kind = 23194
|
||||||
KindNWCWalletResponse Kind = 23195
|
KindNWCWalletResponse Kind = 23195
|
||||||
KindNostrConnect Kind = 24133
|
KindNostrConnect Kind = 24133
|
||||||
KindBlobs Kind = 24242
|
KindBlobs Kind = 24242
|
||||||
KindHTTPAuth Kind = 27235
|
KindHTTPAuth Kind = 27235
|
||||||
KindCategorizedPeopleList Kind = 30000
|
KindCategorizedPeopleList Kind = 30000
|
||||||
KindCategorizedBookmarksList Kind = 30001
|
KindCategorizedBookmarksList Kind = 30001
|
||||||
KindRelaySets Kind = 30002
|
KindRelaySets Kind = 30002
|
||||||
KindBookmarkSets Kind = 30003
|
KindBookmarkSets Kind = 30003
|
||||||
KindCuratedSets Kind = 30004
|
KindCuratedSets Kind = 30004
|
||||||
KindCuratedVideoSets Kind = 30005
|
KindCuratedVideoSets Kind = 30005
|
||||||
KindMuteSets Kind = 30007
|
KindMuteSets Kind = 30007
|
||||||
KindProfileBadges Kind = 30008
|
KindProfileBadges Kind = 30008
|
||||||
KindBadgeDefinition Kind = 30009
|
KindBadgeDefinition Kind = 30009
|
||||||
KindInterestSets Kind = 30015
|
KindInterestSets Kind = 30015
|
||||||
KindStallDefinition Kind = 30017
|
KindStallDefinition Kind = 30017
|
||||||
KindProductDefinition Kind = 30018
|
KindProductDefinition Kind = 30018
|
||||||
KindMarketplaceUI Kind = 30019
|
KindMarketplaceUI Kind = 30019
|
||||||
KindProductSoldAsAuction Kind = 30020
|
KindProductSoldAsAuction Kind = 30020
|
||||||
KindArticle Kind = 30023
|
KindArticle Kind = 30023
|
||||||
KindDraftArticle Kind = 30024
|
KindDraftArticle Kind = 30024
|
||||||
KindEmojiSets Kind = 30030
|
KindEmojiSets Kind = 30030
|
||||||
KindModularArticleHeader Kind = 30040
|
KindModularArticleHeader Kind = 30040
|
||||||
KindModularArticleContent Kind = 30041
|
KindModularArticleContent Kind = 30041
|
||||||
KindReleaseArtifactSets Kind = 30063
|
KindReleaseArtifactSets Kind = 30063
|
||||||
KindApplicationSpecificData Kind = 30078
|
KindApplicationSpecificData Kind = 30078
|
||||||
KindLiveEvent Kind = 30311
|
KindLiveEvent Kind = 30311
|
||||||
KindUserStatuses Kind = 30315
|
KindUserStatuses Kind = 30315
|
||||||
KindClassifiedListing Kind = 30402
|
KindClassifiedListing Kind = 30402
|
||||||
KindDraftClassifiedListing Kind = 30403
|
KindDraftClassifiedListing Kind = 30403
|
||||||
KindRepositoryAnnouncement Kind = 30617
|
KindRepositoryAnnouncement Kind = 30617
|
||||||
KindRepositoryState Kind = 30618
|
KindRepositoryState Kind = 30618
|
||||||
KindSimpleGroupMetadata Kind = 39000
|
KindSimpleGroupMetadata Kind = 39000
|
||||||
KindSimpleGroupAdmins Kind = 39001
|
KindSimpleGroupAdmins Kind = 39001
|
||||||
KindSimpleGroupMembers Kind = 39002
|
KindSimpleGroupMembers Kind = 39002
|
||||||
KindSimpleGroupRoles Kind = 39003
|
KindSimpleGroupRoles Kind = 39003
|
||||||
KindWikiArticle Kind = 30818
|
KindSimpleGroupLiveKitParticipants Kind = 39004
|
||||||
KindRedirects Kind = 30819
|
KindWikiArticle Kind = 30818
|
||||||
KindFeed Kind = 31890
|
KindRedirects Kind = 30819
|
||||||
KindDateCalendarEvent Kind = 31922
|
KindFeed Kind = 31890
|
||||||
KindTimeCalendarEvent Kind = 31923
|
KindDateCalendarEvent Kind = 31922
|
||||||
KindCalendar Kind = 31924
|
KindTimeCalendarEvent Kind = 31923
|
||||||
KindCalendarEventRSVP Kind = 31925
|
KindCalendar Kind = 31924
|
||||||
KindHandlerRecommendation Kind = 31989
|
KindCalendarEventRSVP Kind = 31925
|
||||||
KindHandlerInformation Kind = 31990
|
KindHandlerRecommendation Kind = 31989
|
||||||
KindVideoEvent Kind = 34235
|
KindHandlerInformation Kind = 31990
|
||||||
KindShortVideoEvent Kind = 34236
|
KindVideoEvent Kind = 34235
|
||||||
KindVideoViewEvent Kind = 34237
|
KindShortVideoEvent Kind = 34236
|
||||||
KindCommunityDefinition Kind = 34550
|
KindVideoViewEvent Kind = 34237
|
||||||
|
KindCommunityDefinition Kind = 34550
|
||||||
)
|
)
|
||||||
|
|
||||||
func (kind Kind) IsRegular() bool {
|
func (kind Kind) IsRegular() bool {
|
||||||
|
|||||||
+18
-18
@@ -9,30 +9,30 @@ import (
|
|||||||
|
|
||||||
func TestAddSupportedNIP(t *testing.T) {
|
func TestAddSupportedNIP(t *testing.T) {
|
||||||
info := RelayInformationDocument{}
|
info := RelayInformationDocument{}
|
||||||
info.AddSupportedNIP(12)
|
info.AddSupportedNIP("12")
|
||||||
info.AddSupportedNIP(12)
|
info.AddSupportedNIP("12")
|
||||||
info.AddSupportedNIP(13)
|
info.AddSupportedNIP("13")
|
||||||
info.AddSupportedNIP(1)
|
info.AddSupportedNIP("1")
|
||||||
info.AddSupportedNIP(12)
|
info.AddSupportedNIP("12")
|
||||||
info.AddSupportedNIP(44)
|
info.AddSupportedNIP("44")
|
||||||
info.AddSupportedNIP(2)
|
info.AddSupportedNIP("2")
|
||||||
info.AddSupportedNIP(13)
|
info.AddSupportedNIP("13")
|
||||||
info.AddSupportedNIP(2)
|
info.AddSupportedNIP("2")
|
||||||
info.AddSupportedNIP(13)
|
info.AddSupportedNIP("13")
|
||||||
info.AddSupportedNIP(0)
|
info.AddSupportedNIP("0")
|
||||||
info.AddSupportedNIP(17)
|
info.AddSupportedNIP("17")
|
||||||
info.AddSupportedNIP(19)
|
info.AddSupportedNIP("19")
|
||||||
info.AddSupportedNIP(1)
|
info.AddSupportedNIP("1")
|
||||||
info.AddSupportedNIP(18)
|
info.AddSupportedNIP("18")
|
||||||
|
|
||||||
assert.Contains(t, info.SupportedNIPs, 0, 1, 2, 12, 13, 17, 18, 19, 44)
|
assert.Contains(t, info.SupportedNIPs, "0", "1", "2", "12", "13", "17", "18", "19", "44")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddSupportedNIPs(t *testing.T) {
|
func TestAddSupportedNIPs(t *testing.T) {
|
||||||
info := RelayInformationDocument{}
|
info := RelayInformationDocument{}
|
||||||
info.AddSupportedNIPs([]int{0, 1, 2, 12, 13, 17, 18, 19, 44})
|
info.AddSupportedNIPs([]int{"0", "1", "2", "12", "13", "17", "18", "19", "44"})
|
||||||
|
|
||||||
assert.Contains(t, info.SupportedNIPs, 0, 1, 2, 12, 13, 17, 18, 19, 44)
|
assert.Contains(t, info.SupportedNIPs, "0", "1", "2", "12", "13", "17", "18", "19", "44")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFetch(t *testing.T) {
|
func TestFetch(t *testing.T) {
|
||||||
|
|||||||
+5
-5
@@ -14,7 +14,7 @@ type RelayInformationDocument struct {
|
|||||||
PubKey *nostr.PubKey `json:"pubkey,omitempty"`
|
PubKey *nostr.PubKey `json:"pubkey,omitempty"`
|
||||||
Self *nostr.PubKey `json:"self,omitempty"`
|
Self *nostr.PubKey `json:"self,omitempty"`
|
||||||
Contact string `json:"contact,omitempty"`
|
Contact string `json:"contact,omitempty"`
|
||||||
SupportedNIPs []any `json:"supported_nips,omitempty"`
|
SupportedNIPs []string `json:"supported_nips,omitempty"`
|
||||||
Software string `json:"software,omitempty"`
|
Software string `json:"software,omitempty"`
|
||||||
Version string `json:"version,omitempty"`
|
Version string `json:"version,omitempty"`
|
||||||
|
|
||||||
@@ -33,16 +33,16 @@ type RelayInformationDocument struct {
|
|||||||
SupportedGrasps []string `json:"supported_grasps,omitempty"`
|
SupportedGrasps []string `json:"supported_grasps,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (info *RelayInformationDocument) AddSupportedNIP(number int) {
|
func (info *RelayInformationDocument) AddSupportedNIP(nip string) {
|
||||||
idx := slices.IndexFunc(info.SupportedNIPs, func(n any) bool { return n == number })
|
idx := slices.IndexFunc(info.SupportedNIPs, func(n string) bool { return n == nip })
|
||||||
if idx != -1 {
|
if idx != -1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
info.SupportedNIPs = append(info.SupportedNIPs, number)
|
info.SupportedNIPs = append(info.SupportedNIPs, nip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (info *RelayInformationDocument) AddSupportedNIPs(numbers []int) {
|
func (info *RelayInformationDocument) AddSupportedNIPs(numbers []string) {
|
||||||
for _, n := range numbers {
|
for _, n := range numbers {
|
||||||
info.AddSupportedNIP(n)
|
info.AddSupportedNIP(n)
|
||||||
}
|
}
|
||||||
|
|||||||
+126
-35
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
@@ -38,10 +39,11 @@ func ParseGroupAddress(raw string) (GroupAddress, error) {
|
|||||||
type Group struct {
|
type Group struct {
|
||||||
Address GroupAddress
|
Address GroupAddress
|
||||||
|
|
||||||
Name string
|
Name string
|
||||||
Picture string
|
Picture string
|
||||||
About string
|
About string
|
||||||
Members map[nostr.PubKey][]*Role
|
Members map[nostr.PubKey][]*Role
|
||||||
|
LiveKitParticipants []nostr.PubKey
|
||||||
|
|
||||||
// indicates that only members can read group messages
|
// indicates that only members can read group messages
|
||||||
Private bool
|
Private bool
|
||||||
@@ -55,13 +57,20 @@ type Group struct {
|
|||||||
// indicates that relays should hide group metadata from non-members
|
// indicates that relays should hide group metadata from non-members
|
||||||
Hidden bool
|
Hidden bool
|
||||||
|
|
||||||
|
// indicates that the group supports audio/video live chat
|
||||||
|
LiveKit bool
|
||||||
|
|
||||||
|
// indicates which event kinds this group supports
|
||||||
|
SupportedKinds []nostr.Kind
|
||||||
|
|
||||||
Roles []*Role
|
Roles []*Role
|
||||||
InviteCodes []string
|
InviteCodes []string
|
||||||
|
|
||||||
LastMetadataUpdate nostr.Timestamp
|
LastMetadataUpdate nostr.Timestamp
|
||||||
LastAdminsUpdate nostr.Timestamp
|
LastAdminsUpdate nostr.Timestamp
|
||||||
LastMembersUpdate nostr.Timestamp
|
LastMembersUpdate nostr.Timestamp
|
||||||
LastRolesUpdate nostr.Timestamp
|
LastRolesUpdate nostr.Timestamp
|
||||||
|
LastLiveKitParticipantsUpdate nostr.Timestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (group Group) String() string {
|
func (group Group) String() string {
|
||||||
@@ -83,6 +92,11 @@ func (group Group) String() string {
|
|||||||
maybeClosed = " closed"
|
maybeClosed = " closed"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maybeLiveKit := ""
|
||||||
|
if group.LiveKit {
|
||||||
|
maybeLiveKit = " livekit"
|
||||||
|
}
|
||||||
|
|
||||||
members := make([]string, len(group.Members))
|
members := make([]string, len(group.Members))
|
||||||
i := 0
|
i := 0
|
||||||
for pubkey, roles := range group.Members {
|
for pubkey, roles := range group.Members {
|
||||||
@@ -101,13 +115,14 @@ func (group Group) String() string {
|
|||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(`<Group %s name="%s"%s%s%s%s picture="%s" about="%s" members=[%v]>`,
|
return fmt.Sprintf(`<Group %s name="%s"%s%s%s%s%s picture="%s" about="%s" members=[%v]>`,
|
||||||
group.Address,
|
group.Address,
|
||||||
group.Name,
|
group.Name,
|
||||||
maybePrivate,
|
maybePrivate,
|
||||||
maybeRestricted,
|
maybeRestricted,
|
||||||
maybeHidden,
|
maybeHidden,
|
||||||
maybeClosed,
|
maybeClosed,
|
||||||
|
maybeLiveKit,
|
||||||
group.Picture,
|
group.Picture,
|
||||||
group.About,
|
group.About,
|
||||||
strings.Join(members, " "),
|
strings.Join(members, " "),
|
||||||
@@ -122,9 +137,10 @@ func NewGroup(gadstr string) (Group, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Group{
|
return Group{
|
||||||
Address: gad,
|
Address: gad,
|
||||||
Name: gad.ID,
|
Name: gad.ID,
|
||||||
Members: make(map[nostr.PubKey][]*Role),
|
Members: make(map[nostr.PubKey][]*Role),
|
||||||
|
LiveKitParticipants: make([]nostr.PubKey, 0),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,8 +150,9 @@ func NewGroupFromMetadataEvent(relayURL string, evt *nostr.Event) (Group, error)
|
|||||||
Relay: relayURL,
|
Relay: relayURL,
|
||||||
ID: evt.Tags.GetD(),
|
ID: evt.Tags.GetD(),
|
||||||
},
|
},
|
||||||
Name: evt.Tags.GetD(),
|
Name: evt.Tags.GetD(),
|
||||||
Members: make(map[nostr.PubKey][]*Role),
|
Members: make(map[nostr.PubKey][]*Role),
|
||||||
|
LiveKitParticipants: make([]nostr.PubKey, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := g.MergeInMetadataEvent(evt)
|
err := g.MergeInMetadataEvent(evt)
|
||||||
@@ -173,6 +190,18 @@ func (group Group) ToMetadataEvent() nostr.Event {
|
|||||||
if group.Closed {
|
if group.Closed {
|
||||||
evt.Tags = append(evt.Tags, nostr.Tag{"closed"})
|
evt.Tags = append(evt.Tags, nostr.Tag{"closed"})
|
||||||
}
|
}
|
||||||
|
if group.LiveKit {
|
||||||
|
evt.Tags = append(evt.Tags, nostr.Tag{"livekit"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if group.SupportedKinds != nil {
|
||||||
|
tag := make(nostr.Tag, 1, 1+len(group.SupportedKinds))
|
||||||
|
tag[0] = "supported_kinds"
|
||||||
|
for _, kind := range group.SupportedKinds {
|
||||||
|
tag = append(tag, strconv.Itoa(int(kind)))
|
||||||
|
}
|
||||||
|
evt.Tags = append(evt.Tags, tag)
|
||||||
|
}
|
||||||
|
|
||||||
return evt
|
return evt
|
||||||
}
|
}
|
||||||
@@ -236,6 +265,22 @@ func (group Group) ToRolesEvent() nostr.Event {
|
|||||||
return evt
|
return evt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (group Group) ToLiveKitParticipantsEvent() nostr.Event {
|
||||||
|
evt := nostr.Event{
|
||||||
|
Kind: nostr.KindSimpleGroupLiveKitParticipants,
|
||||||
|
CreatedAt: group.LastLiveKitParticipantsUpdate,
|
||||||
|
Tags: make(nostr.Tags, 1, 1+len(group.LiveKitParticipants)),
|
||||||
|
}
|
||||||
|
evt.Tags[0] = nostr.Tag{"d", group.Address.ID}
|
||||||
|
|
||||||
|
for _, member := range group.LiveKitParticipants {
|
||||||
|
tag := nostr.Tag{"participant", member.Hex()}
|
||||||
|
evt.Tags = append(evt.Tags, tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
|
||||||
func (group *Group) MergeInMetadataEvent(evt *nostr.Event) error {
|
func (group *Group) MergeInMetadataEvent(evt *nostr.Event) error {
|
||||||
if evt.Kind != nostr.KindSimpleGroupMetadata {
|
if evt.Kind != nostr.KindSimpleGroupMetadata {
|
||||||
return fmt.Errorf("expected kind %d, got %d", nostr.KindSimpleGroupMetadata, evt.Kind)
|
return fmt.Errorf("expected kind %d, got %d", nostr.KindSimpleGroupMetadata, evt.Kind)
|
||||||
@@ -247,27 +292,42 @@ func (group *Group) MergeInMetadataEvent(evt *nostr.Event) error {
|
|||||||
group.LastMetadataUpdate = evt.CreatedAt
|
group.LastMetadataUpdate = evt.CreatedAt
|
||||||
group.Name = group.Address.ID
|
group.Name = group.Address.ID
|
||||||
|
|
||||||
if tag := evt.Tags.Find("name"); tag != nil {
|
for _, tag := range evt.Tags {
|
||||||
group.Name = tag[1]
|
if len(tag) >= 1 {
|
||||||
}
|
switch tag[0] {
|
||||||
if tag := evt.Tags.Find("about"); tag != nil {
|
case "private":
|
||||||
group.About = tag[1]
|
group.Private = true
|
||||||
}
|
case "restricted":
|
||||||
if tag := evt.Tags.Find("picture"); tag != nil {
|
group.Restricted = true
|
||||||
group.Picture = tag[1]
|
case "closed":
|
||||||
}
|
group.Closed = true
|
||||||
|
case "hidden":
|
||||||
if tag := evt.Tags.Find("private"); tag != nil {
|
group.Hidden = true
|
||||||
group.Private = true
|
case "livekit":
|
||||||
}
|
group.LiveKit = true
|
||||||
if tag := evt.Tags.Find("restricted"); tag != nil {
|
case "supported_kinds":
|
||||||
group.Restricted = true
|
kinds := make([]nostr.Kind, 0, len(tag)-1)
|
||||||
}
|
for _, raw := range tag[1:] {
|
||||||
if tag := evt.Tags.Find("hidden"); tag != nil {
|
kind, err := strconv.Atoi(raw)
|
||||||
group.Hidden = true
|
if err != nil {
|
||||||
}
|
continue
|
||||||
if tag := evt.Tags.Find("closed"); tag != nil {
|
}
|
||||||
group.Closed = true
|
kinds = append(kinds, nostr.Kind(kind))
|
||||||
|
}
|
||||||
|
group.SupportedKinds = kinds
|
||||||
|
default:
|
||||||
|
if len(tag) >= 2 {
|
||||||
|
switch tag[0] {
|
||||||
|
case "name":
|
||||||
|
group.Name = tag[1]
|
||||||
|
case "about":
|
||||||
|
group.About = tag[1]
|
||||||
|
case "picture":
|
||||||
|
group.Picture = tag[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -368,3 +428,34 @@ func (group *Group) MergeInRolesEvent(evt *nostr.Event) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (group *Group) MergeInLiveKitParticipantsEvent(evt *nostr.Event) error {
|
||||||
|
if evt.Kind != nostr.KindSimpleGroupLiveKitParticipants {
|
||||||
|
return fmt.Errorf("expected kind %d, got %d", nostr.KindSimpleGroupLiveKitParticipants, evt.Kind)
|
||||||
|
}
|
||||||
|
if evt.CreatedAt < group.LastLiveKitParticipantsUpdate {
|
||||||
|
return fmt.Errorf("event is older than our last update (%d vs %d)", evt.CreatedAt, group.LastLiveKitParticipantsUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
group.LastLiveKitParticipantsUpdate = evt.CreatedAt
|
||||||
|
group.LiveKitParticipants = make([]nostr.PubKey, 0, len(evt.Tags))
|
||||||
|
for _, tag := range evt.Tags {
|
||||||
|
if len(tag) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tag[0] != "participant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
member, err := nostr.PubKeyFromHex(tag[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if slices.Contains(group.LiveKitParticipants, member) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
group.LiveKitParticipants = append(group.LiveKitParticipants, member)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
+129
-48
@@ -3,6 +3,7 @@ package nip29
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
)
|
)
|
||||||
@@ -78,48 +79,101 @@ var moderationActionFactories = map[nostr.Kind]func(nostr.Event) (Action, error)
|
|||||||
nostr.KindSimpleGroupEditMetadata: func(evt nostr.Event) (Action, error) {
|
nostr.KindSimpleGroupEditMetadata: func(evt nostr.Event) (Action, error) {
|
||||||
ok := false
|
ok := false
|
||||||
edit := EditMetadata{When: evt.CreatedAt}
|
edit := EditMetadata{When: evt.CreatedAt}
|
||||||
if t := evt.Tags.Find("name"); t != nil {
|
|
||||||
edit.NameValue = &t[1]
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
if t := evt.Tags.Find("picture"); t != nil {
|
|
||||||
edit.PictureValue = &t[1]
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
if t := evt.Tags.Find("about"); t != nil {
|
|
||||||
edit.AboutValue = &t[1]
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
|
|
||||||
y := true
|
y := true
|
||||||
n := false
|
n := false
|
||||||
if evt.Tags.Has("closed") {
|
|
||||||
edit.ClosedValue = &y
|
hasName := false
|
||||||
ok = true
|
|
||||||
} else if evt.Tags.Has("open") {
|
// DEPRECATED: remove all the fields not tagged with Replace = true eventually
|
||||||
edit.ClosedValue = &n
|
// edit-metadata to become a PUT rather than a PATCH
|
||||||
ok = true
|
|
||||||
}
|
for _, tag := range evt.Tags {
|
||||||
if evt.Tags.Has("restricted") {
|
if len(tag) >= 1 {
|
||||||
edit.RestrictedValue = &y
|
switch tag[0] {
|
||||||
ok = true
|
case "name":
|
||||||
} else if evt.Tags.Has("unrestricted") {
|
if len(tag) >= 2 {
|
||||||
edit.RestrictedValue = &n
|
edit.NameValue = &tag[1]
|
||||||
ok = true
|
if ok {
|
||||||
}
|
edit.Replace = true
|
||||||
if evt.Tags.Has("hidden") {
|
}
|
||||||
edit.HiddenValue = &y
|
ok = true
|
||||||
ok = true
|
hasName = true
|
||||||
} else if evt.Tags.Has("visible") {
|
}
|
||||||
edit.HiddenValue = &n
|
case "picture":
|
||||||
ok = true
|
if len(tag) >= 2 {
|
||||||
}
|
edit.PictureValue = &tag[1]
|
||||||
if evt.Tags.Has("private") {
|
if hasName {
|
||||||
edit.PrivateValue = &y
|
edit.Replace = true
|
||||||
ok = true
|
}
|
||||||
} else if evt.Tags.Has("public") {
|
ok = true
|
||||||
edit.PrivateValue = &n
|
}
|
||||||
ok = true
|
case "about":
|
||||||
|
if len(tag) >= 2 {
|
||||||
|
edit.AboutValue = &tag[1]
|
||||||
|
if hasName {
|
||||||
|
edit.Replace = true
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
case "supported_kinds":
|
||||||
|
kinds := make([]nostr.Kind, 0, len(tag)-1)
|
||||||
|
for _, kstr := range tag[1:] {
|
||||||
|
if kind, err := strconv.ParseUint(kstr, 10, 16); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid kind: %w", err)
|
||||||
|
} else {
|
||||||
|
kinds = append(kinds, nostr.Kind(kind))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
edit.SupportedKindsValue = &kinds
|
||||||
|
edit.Replace = true
|
||||||
|
case "closed":
|
||||||
|
edit.ClosedValue = &y
|
||||||
|
if hasName {
|
||||||
|
edit.Replace = true
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
case "open":
|
||||||
|
edit.ClosedValue = &n
|
||||||
|
ok = true
|
||||||
|
case "restricted":
|
||||||
|
edit.RestrictedValue = &y
|
||||||
|
if hasName {
|
||||||
|
edit.Replace = true
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
case "unrestricted":
|
||||||
|
edit.RestrictedValue = &n
|
||||||
|
ok = true
|
||||||
|
case "hidden":
|
||||||
|
edit.HiddenValue = &y
|
||||||
|
if hasName {
|
||||||
|
edit.Replace = true
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
case "visible":
|
||||||
|
edit.HiddenValue = &n
|
||||||
|
ok = true
|
||||||
|
case "private":
|
||||||
|
edit.PrivateValue = &y
|
||||||
|
if hasName {
|
||||||
|
edit.Replace = true
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
case "public":
|
||||||
|
edit.PrivateValue = &n
|
||||||
|
ok = true
|
||||||
|
case "livekit":
|
||||||
|
edit.LiveKitValue = &y
|
||||||
|
edit.Replace = true
|
||||||
|
ok = true
|
||||||
|
case "no-livekit":
|
||||||
|
edit.LiveKitValue = &n
|
||||||
|
ok = true
|
||||||
|
case "no-text":
|
||||||
|
edit.SupportedKindsValue = nil
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
@@ -226,19 +280,36 @@ func (a RemoveUser) Apply(group *Group) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EditMetadata struct {
|
type EditMetadata struct {
|
||||||
NameValue *string
|
NameValue *string
|
||||||
PictureValue *string
|
PictureValue *string
|
||||||
AboutValue *string
|
AboutValue *string
|
||||||
RestrictedValue *bool
|
RestrictedValue *bool
|
||||||
ClosedValue *bool
|
ClosedValue *bool
|
||||||
HiddenValue *bool
|
HiddenValue *bool
|
||||||
PrivateValue *bool
|
PrivateValue *bool
|
||||||
When nostr.Timestamp
|
LiveKitValue *bool
|
||||||
|
SupportedKindsValue *[]nostr.Kind
|
||||||
|
|
||||||
|
Replace bool
|
||||||
|
When nostr.Timestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (_ EditMetadata) Name() string { return "edit-metadata" }
|
func (_ EditMetadata) Name() string { return "edit-metadata" }
|
||||||
func (a EditMetadata) Apply(group *Group) {
|
func (a EditMetadata) Apply(group *Group) {
|
||||||
group.LastMetadataUpdate = a.When
|
group.LastMetadataUpdate = a.When
|
||||||
|
|
||||||
|
if a.Replace {
|
||||||
|
group.Name = ""
|
||||||
|
group.Picture = ""
|
||||||
|
group.About = ""
|
||||||
|
group.Restricted = false
|
||||||
|
group.Closed = false
|
||||||
|
group.Hidden = false
|
||||||
|
group.Private = false
|
||||||
|
group.LiveKit = false
|
||||||
|
group.SupportedKinds = nil
|
||||||
|
}
|
||||||
|
|
||||||
if a.NameValue != nil {
|
if a.NameValue != nil {
|
||||||
group.Name = *a.NameValue
|
group.Name = *a.NameValue
|
||||||
}
|
}
|
||||||
@@ -260,6 +331,12 @@ func (a EditMetadata) Apply(group *Group) {
|
|||||||
if a.PrivateValue != nil {
|
if a.PrivateValue != nil {
|
||||||
group.Private = *a.PrivateValue
|
group.Private = *a.PrivateValue
|
||||||
}
|
}
|
||||||
|
if a.LiveKitValue != nil {
|
||||||
|
group.LiveKit = *a.LiveKitValue
|
||||||
|
}
|
||||||
|
if a.SupportedKindsValue != nil {
|
||||||
|
group.SupportedKinds = *a.SupportedKindsValue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateGroup struct {
|
type CreateGroup struct {
|
||||||
@@ -272,6 +349,7 @@ func (a CreateGroup) Apply(group *Group) {
|
|||||||
group.LastMetadataUpdate = a.When
|
group.LastMetadataUpdate = a.When
|
||||||
group.LastAdminsUpdate = a.When
|
group.LastAdminsUpdate = a.When
|
||||||
group.LastMembersUpdate = a.When
|
group.LastMembersUpdate = a.When
|
||||||
|
group.LastLiveKitParticipantsUpdate = a.When
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeleteGroup struct {
|
type DeleteGroup struct {
|
||||||
@@ -281,6 +359,7 @@ type DeleteGroup struct {
|
|||||||
func (_ DeleteGroup) Name() string { return "delete-group" }
|
func (_ DeleteGroup) Name() string { return "delete-group" }
|
||||||
func (a DeleteGroup) Apply(group *Group) {
|
func (a DeleteGroup) Apply(group *Group) {
|
||||||
group.Members = make(map[nostr.PubKey][]*Role)
|
group.Members = make(map[nostr.PubKey][]*Role)
|
||||||
|
group.LiveKitParticipants = make([]nostr.PubKey, 0)
|
||||||
group.Closed = true
|
group.Closed = true
|
||||||
group.Private = true
|
group.Private = true
|
||||||
group.Restricted = true
|
group.Restricted = true
|
||||||
@@ -288,9 +367,11 @@ func (a DeleteGroup) Apply(group *Group) {
|
|||||||
group.Name = "[deleted]"
|
group.Name = "[deleted]"
|
||||||
group.About = ""
|
group.About = ""
|
||||||
group.Picture = ""
|
group.Picture = ""
|
||||||
|
group.LiveKit = false
|
||||||
group.LastMetadataUpdate = a.When
|
group.LastMetadataUpdate = a.When
|
||||||
group.LastAdminsUpdate = a.When
|
group.LastAdminsUpdate = a.When
|
||||||
group.LastMembersUpdate = a.When
|
group.LastMembersUpdate = a.When
|
||||||
|
group.LastLiveKitParticipantsUpdate = a.When
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateInvite struct {
|
type CreateInvite struct {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ var MetadataEventKinds = KindRange{
|
|||||||
nostr.KindSimpleGroupAdmins,
|
nostr.KindSimpleGroupAdmins,
|
||||||
nostr.KindSimpleGroupMembers,
|
nostr.KindSimpleGroupMembers,
|
||||||
nostr.KindSimpleGroupRoles,
|
nostr.KindSimpleGroupRoles,
|
||||||
|
nostr.KindSimpleGroupLiveKitParticipants,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kr KindRange) Includes(kind nostr.Kind) bool {
|
func (kr KindRange) Includes(kind nostr.Kind) bool {
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Person struct {
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
Timestamp int64
|
||||||
|
Timezone string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Commit struct {
|
||||||
|
Hash string
|
||||||
|
Tree string
|
||||||
|
Parents []string
|
||||||
|
Author Person
|
||||||
|
Committer Person
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseCommit(data []byte, hash string) (*Commit, error) {
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
headerEndIndex := strings.Index(content, "\n\n")
|
||||||
|
if headerEndIndex == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid commit format for %s: no message separator found", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := content[:headerEndIndex]
|
||||||
|
message := content[headerEndIndex+2:]
|
||||||
|
|
||||||
|
lines := strings.Split(header, "\n")
|
||||||
|
result := &Commit{
|
||||||
|
Hash: hash,
|
||||||
|
Parents: []string{},
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.HasPrefix(line, "tree ") {
|
||||||
|
result.Tree = line[5:]
|
||||||
|
} else if strings.HasPrefix(line, "parent ") {
|
||||||
|
result.Parents = append(result.Parents, line[7:])
|
||||||
|
} else if strings.HasPrefix(line, "author ") {
|
||||||
|
person, err := parsePerson(line[7:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid author in commit %s: %w", hash, err)
|
||||||
|
}
|
||||||
|
result.Author = person
|
||||||
|
} else if strings.HasPrefix(line, "committer ") {
|
||||||
|
person, err := parsePerson(line[10:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid committer in commit %s: %w", hash, err)
|
||||||
|
}
|
||||||
|
result.Committer = person
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Tree == "" {
|
||||||
|
return nil, fmt.Errorf("invalid commit format for %s: missing tree", hash)
|
||||||
|
}
|
||||||
|
if result.Author.Name == "" {
|
||||||
|
return nil, fmt.Errorf("invalid commit format for %s: missing author", hash)
|
||||||
|
}
|
||||||
|
if result.Committer.Name == "" {
|
||||||
|
return nil, fmt.Errorf("invalid commit format for %s: missing committer", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePerson(line string) (Person, error) {
|
||||||
|
emailStart := strings.Index(line, " <")
|
||||||
|
if emailStart == -1 {
|
||||||
|
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||||
|
}
|
||||||
|
name := line[:emailStart]
|
||||||
|
|
||||||
|
emailEnd := strings.Index(line[emailStart+2:], ">")
|
||||||
|
if emailEnd == -1 {
|
||||||
|
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||||
|
}
|
||||||
|
email := line[emailStart+2 : emailStart+2+emailEnd]
|
||||||
|
|
||||||
|
rest := strings.TrimSpace(line[emailStart+2+emailEnd+1:])
|
||||||
|
parts := strings.SplitN(rest, " ", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp, err := strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return Person{}, fmt.Errorf("invalid timestamp in person: %s", parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return Person{
|
||||||
|
Name: name,
|
||||||
|
Email: email,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Timezone: parts[1],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,413 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DiffLine struct {
|
||||||
|
Index int
|
||||||
|
Status string
|
||||||
|
Text string
|
||||||
|
Change string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiffFile struct {
|
||||||
|
Path string
|
||||||
|
Status string
|
||||||
|
Content []byte
|
||||||
|
Lines []DiffLine
|
||||||
|
}
|
||||||
|
|
||||||
|
type changedEntry struct {
|
||||||
|
newVersion TreeEntry
|
||||||
|
oldVersions []TreeEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCommitDiff(url string, commitOrRef string) ([]DiffFile, error) {
|
||||||
|
commit, err := GetSingleCommit(url, commitOrRef)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
added := make(map[string]TreeEntry)
|
||||||
|
deleted := make(map[string]TreeEntry)
|
||||||
|
changed := make(map[string]*changedEntry)
|
||||||
|
unchanged := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, parent := range commit.Parents {
|
||||||
|
parentCommit, err := GetSingleCommit(url, parent)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = computeTreeDiffs(url, commit.Tree, parentCommit.Tree, "", added, deleted, changed, unchanged)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var diff []DiffFile
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var firstErr error
|
||||||
|
|
||||||
|
for path, entry := range changed {
|
||||||
|
p := path
|
||||||
|
e := entry
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
curr, err := GetObject(url, e.newVersion.Hash)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if curr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(e.oldVersions) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oldObj, err := GetObject(url, e.oldVersions[0].Hash)
|
||||||
|
if err != nil || oldObj == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isBinary(curr.Data) || isBinary(oldObj.Data) {
|
||||||
|
mu.Lock()
|
||||||
|
diff = append(diff, DiffFile{
|
||||||
|
Path: p,
|
||||||
|
Status: "changed-binary",
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := diffTextLines(oldObj.Data, curr.Data)
|
||||||
|
mu.Lock()
|
||||||
|
diff = append(diff, DiffFile{
|
||||||
|
Path: p,
|
||||||
|
Status: "changed",
|
||||||
|
Lines: lines,
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for path, entry := range deleted {
|
||||||
|
p := path
|
||||||
|
e := entry
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
obj, err := GetObject(url, e.Hash)
|
||||||
|
if err != nil || obj == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
diff = append(diff, DiffFile{
|
||||||
|
Path: p,
|
||||||
|
Status: "deleted",
|
||||||
|
Content: obj.Data,
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for path, entry := range added {
|
||||||
|
p := path
|
||||||
|
e := entry
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
obj, err := GetObject(url, e.Hash)
|
||||||
|
if err != nil || obj == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
diff = append(diff, DiffFile{
|
||||||
|
Path: p,
|
||||||
|
Status: "added",
|
||||||
|
Content: obj.Data,
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if firstErr != nil {
|
||||||
|
return nil, firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return diff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBinary(data []byte) bool {
|
||||||
|
for _, b := range data {
|
||||||
|
if b == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func diffTextLines(oldData []byte, newData []byte) []DiffLine {
|
||||||
|
oldText := string(oldData)
|
||||||
|
newText := string(newData)
|
||||||
|
oldLines := splitLines(oldText)
|
||||||
|
newLines := splitLines(newText)
|
||||||
|
|
||||||
|
ops := lcsOperations(oldLines, newLines)
|
||||||
|
allLines := make([]DiffLine, 0, len(ops))
|
||||||
|
oldIndex := 1
|
||||||
|
newIndex := 1
|
||||||
|
|
||||||
|
for i := 0; i < len(ops); i++ {
|
||||||
|
op := ops[i]
|
||||||
|
var next *lcsOp
|
||||||
|
if i+1 < len(ops) {
|
||||||
|
next = &ops[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.typ == "del" && next != nil && next.typ == "add" {
|
||||||
|
allLines = append(allLines, DiffLine{
|
||||||
|
Status: "changed",
|
||||||
|
Index: newIndex,
|
||||||
|
Text: next.line,
|
||||||
|
})
|
||||||
|
oldIndex++
|
||||||
|
newIndex++
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.typ == "add" && next != nil && next.typ == "del" {
|
||||||
|
allLines = append(allLines, DiffLine{
|
||||||
|
Status: "changed",
|
||||||
|
Index: newIndex,
|
||||||
|
Text: op.line,
|
||||||
|
})
|
||||||
|
oldIndex++
|
||||||
|
newIndex++
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.typ == "add" {
|
||||||
|
allLines = append(allLines, DiffLine{
|
||||||
|
Status: "added",
|
||||||
|
Index: newIndex,
|
||||||
|
Text: op.line,
|
||||||
|
})
|
||||||
|
newIndex++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if op.typ == "del" {
|
||||||
|
allLines = append(allLines, DiffLine{
|
||||||
|
Status: "deleted",
|
||||||
|
Index: oldIndex,
|
||||||
|
Text: op.line,
|
||||||
|
})
|
||||||
|
oldIndex++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
oldIndex++
|
||||||
|
newIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allLines) == 0 {
|
||||||
|
return allLines
|
||||||
|
}
|
||||||
|
|
||||||
|
keep := make([]bool, len(allLines))
|
||||||
|
for i := 0; i < len(allLines); i++ {
|
||||||
|
start := i - 3
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
end := i + 3
|
||||||
|
if end >= len(allLines) {
|
||||||
|
end = len(allLines) - 1
|
||||||
|
}
|
||||||
|
for j := start; j <= end; j++ {
|
||||||
|
keep[j] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]DiffLine, 0, len(allLines))
|
||||||
|
for i := 0; i < len(allLines); i++ {
|
||||||
|
if keep[i] {
|
||||||
|
result = append(result, allLines[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type lcsOp struct {
|
||||||
|
typ string
|
||||||
|
line string
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitLines(text string) []string {
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||||
|
lines = lines[:len(lines)-1]
|
||||||
|
}
|
||||||
|
return lines
|
||||||
|
}
|
||||||
|
|
||||||
|
func lcsOperations(oldLines []string, newLines []string) []lcsOp {
|
||||||
|
n := len(oldLines)
|
||||||
|
m := len(newLines)
|
||||||
|
|
||||||
|
dp := make([][]uint32, n+1)
|
||||||
|
for i := range dp {
|
||||||
|
dp[i] = make([]uint32, m+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= n; i++ {
|
||||||
|
for j := 1; j <= m; j++ {
|
||||||
|
if oldLines[i-1] == newLines[j-1] {
|
||||||
|
dp[i][j] = dp[i-1][j-1] + 1
|
||||||
|
} else if dp[i-1][j] >= dp[i][j-1] {
|
||||||
|
dp[i][j] = dp[i-1][j]
|
||||||
|
} else {
|
||||||
|
dp[i][j] = dp[i][j-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ops := make([]lcsOp, 0, n+m)
|
||||||
|
i := n
|
||||||
|
j := m
|
||||||
|
for i > 0 || j > 0 {
|
||||||
|
if i > 0 && j > 0 && oldLines[i-1] == newLines[j-1] {
|
||||||
|
ops = append(ops, lcsOp{typ: "equal", line: oldLines[i-1]})
|
||||||
|
i--
|
||||||
|
j--
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if i > 0 && (j == 0 || dp[i-1][j] >= dp[i][j-1]) {
|
||||||
|
ops = append(ops, lcsOp{typ: "del", line: oldLines[i-1]})
|
||||||
|
i--
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if j > 0 {
|
||||||
|
ops = append(ops, lcsOp{typ: "add", line: newLines[j-1]})
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
ops[i], ops[j] = ops[j], ops[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return ops
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeTreeDiffs(
|
||||||
|
url string,
|
||||||
|
treeHash string,
|
||||||
|
parentTreeHash string,
|
||||||
|
basePath string,
|
||||||
|
added map[string]TreeEntry,
|
||||||
|
deleted map[string]TreeEntry,
|
||||||
|
changed map[string]*changedEntry,
|
||||||
|
unchanged map[string]bool,
|
||||||
|
) error {
|
||||||
|
var newTree []TreeEntry
|
||||||
|
var oldTree []TreeEntry
|
||||||
|
|
||||||
|
if treeHash != "" {
|
||||||
|
obj, err := GetObject(url, treeHash)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if obj != nil {
|
||||||
|
newTree = ParseTree(obj.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if parentTreeHash != "" {
|
||||||
|
obj, err := GetObject(url, parentTreeHash)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if obj != nil {
|
||||||
|
oldTree = ParseTree(obj.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range newTree {
|
||||||
|
var old *TreeEntry
|
||||||
|
for _, o := range oldTree {
|
||||||
|
if o.Path == entry.Path {
|
||||||
|
o := o
|
||||||
|
old = &o
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if old != nil {
|
||||||
|
delete(added, basePath+entry.Path)
|
||||||
|
|
||||||
|
if old.Hash == entry.Hash {
|
||||||
|
unchanged[basePath+entry.Path] = true
|
||||||
|
} else {
|
||||||
|
if entry.IsDir {
|
||||||
|
err := computeTreeDiffs(url, entry.Hash, old.Hash, basePath+entry.Path+"/", added, deleted, changed, unchanged)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if existing, exists := changed[basePath+entry.Path]; !exists {
|
||||||
|
changed[basePath+entry.Path] = &changedEntry{
|
||||||
|
newVersion: entry,
|
||||||
|
oldVersions: []TreeEntry{*old},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
existing.oldVersions = append(existing.oldVersions, *old)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if entry.IsDir {
|
||||||
|
err := computeTreeDiffs(url, entry.Hash, "", basePath+entry.Path+"/", added, deleted, changed, unchanged)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
added[basePath+entry.Path] = entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, old := range oldTree {
|
||||||
|
if unchanged[basePath+old.Path] || changed[basePath+old.Path] != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if old.IsDir {
|
||||||
|
err := computeTreeDiffs(url, "", old.Hash, basePath+old.Path+"/", added, deleted, changed, unchanged)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
deleted[basePath+old.Path] = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MissingCapability struct {
|
||||||
|
URL string
|
||||||
|
Capability string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *MissingCapability) Error() string {
|
||||||
|
return fmt.Sprintf("server at %s is missing required capability %s", e.URL, e.Capability)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareRequest(url string, commitOrRef string, needFilter bool) (resolvedRef string, capabilities []string, err error) {
|
||||||
|
var info *InfoRefsUploadPackResponse
|
||||||
|
if strings.HasPrefix(commitOrRef, "refs/") {
|
||||||
|
info, err = GetInfoRefs(url)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
resolved, ok := info.Refs[commitOrRef]
|
||||||
|
if !ok {
|
||||||
|
return "", nil, fmt.Errorf("ref %s not found", commitOrRef)
|
||||||
|
}
|
||||||
|
commitOrRef = resolved
|
||||||
|
}
|
||||||
|
|
||||||
|
caps, err := GetCapabilities(url, info)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range DefaultCapabilities {
|
||||||
|
if slices.Contains(caps, c) {
|
||||||
|
capabilities = append(capabilities, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, c := range NecessaryCapabilities {
|
||||||
|
if slices.Contains(caps, c) {
|
||||||
|
capabilities = append(capabilities, c)
|
||||||
|
} else {
|
||||||
|
return "", nil, &MissingCapability{URL: url, Capability: c}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, c := range RequiredCapabilities {
|
||||||
|
if !slices.Contains(caps, c) {
|
||||||
|
return "", nil, &MissingCapability{URL: url, Capability: c}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if needFilter {
|
||||||
|
if slices.Contains(caps, "filter") {
|
||||||
|
capabilities = append(capabilities, "filter")
|
||||||
|
} else {
|
||||||
|
return "", nil, &MissingCapability{URL: url, Capability: "filter"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commitOrRef, capabilities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetObject(url string, blobHash string) (*ParsedObject, error) {
|
||||||
|
ref, caps, err := prepareRequest(url, blobHash, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deepen := 1
|
||||||
|
want, err := CreateWantRequest(ref, caps, &deepen, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FetchPackfile(url, want)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Objects[blobHash], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDirectoryTreeAt(url string, commitOrRef string, nestLimit *int) (*Tree, error) {
|
||||||
|
ref, caps, err := prepareRequest(url, commitOrRef, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
want, err := CreateWantRequest(ref, caps, nestLimit, "blob:none")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FetchPackfile(url, want)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
commit := result.Objects[ref]
|
||||||
|
if commit == nil {
|
||||||
|
return nil, fmt.Errorf("commit %s not found in packfile", ref)
|
||||||
|
}
|
||||||
|
|
||||||
|
treeHash := string(commit.Data[5:45])
|
||||||
|
rootTree := result.Objects[treeHash]
|
||||||
|
if rootTree == nil {
|
||||||
|
return nil, fmt.Errorf("root tree %s not found in packfile", treeHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
return LoadTree(rootTree, result.Objects, nestLimit), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ShallowCloneRepositoryAt(url string, commitOrRef string) (*Commit, *Tree, error) {
|
||||||
|
ref, caps, err := prepareRequest(url, commitOrRef, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deepen := 1
|
||||||
|
want, err := CreateWantRequest(ref, caps, &deepen, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FetchPackfile(url, want)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
commitObj := result.Objects[ref]
|
||||||
|
if commitObj == nil {
|
||||||
|
return nil, nil, fmt.Errorf("commit %s not found in packfile", ref)
|
||||||
|
}
|
||||||
|
|
||||||
|
treeHash := string(commitObj.Data[5:45])
|
||||||
|
rootTree := result.Objects[treeHash]
|
||||||
|
if rootTree == nil {
|
||||||
|
return nil, nil, fmt.Errorf("root tree %s not found in packfile", treeHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
commit, err := ParseCommit(commitObj.Data, commitObj.Hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tree := LoadTree(rootTree, result.Objects, nil)
|
||||||
|
return commit, tree, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchCommitsOnly(url string, commitOrRef string, maxCommits *int) ([]*Commit, error) {
|
||||||
|
ref, caps, err := prepareRequest(url, commitOrRef, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
want, err := CreateWantRequest(ref, caps, maxCommits, "tree:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FetchPackfile(url, want)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
commitMap := make(map[string]*Commit, len(result.Objects))
|
||||||
|
for hash, obj := range result.Objects {
|
||||||
|
commit, err := ParseCommit(obj.Data, hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
commitMap[hash] = commit
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort topologically starting from the requested ref
|
||||||
|
sorted := make([]*Commit, 0, len(commitMap))
|
||||||
|
visited := make(map[string]bool, len(commitMap))
|
||||||
|
var visit func(hash string)
|
||||||
|
visit = func(hash string) {
|
||||||
|
if visited[hash] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c, ok := commitMap[hash]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
visited[hash] = true
|
||||||
|
sorted = append(sorted, c)
|
||||||
|
for _, parent := range c.Parents {
|
||||||
|
visit(parent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
visit(ref)
|
||||||
|
|
||||||
|
for _, c := range commitMap {
|
||||||
|
if !visited[c.Hash] {
|
||||||
|
sorted = append(sorted, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sorted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSingleCommit(url string, commitOrRef string) (*Commit, error) {
|
||||||
|
maxCommits := 1
|
||||||
|
commits, err := FetchCommitsOnly(url, commitOrRef, &maxCommits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(commits) == 0 {
|
||||||
|
return nil, fmt.Errorf("no commit found for reference: %s", commitOrRef)
|
||||||
|
}
|
||||||
|
return commits[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetObjectByPath(url string, commitOrRef string, path string) (*TreeEntry, error) {
|
||||||
|
normalizedPath := strings.ReplaceAll(path, "\\", "/")
|
||||||
|
normalizedPath = strings.TrimLeft(normalizedPath, "/")
|
||||||
|
normalizedPath = strings.TrimRight(normalizedPath, "/")
|
||||||
|
|
||||||
|
var pathSegments []string
|
||||||
|
if normalizedPath != "" {
|
||||||
|
pathSegments = strings.Split(normalizedPath, "/")
|
||||||
|
}
|
||||||
|
requiredDepth := len(pathSegments)
|
||||||
|
|
||||||
|
tree, err := GetDirectoryTreeAt(url, commitOrRef, &requiredDepth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
currentLevel := tree
|
||||||
|
nextSegment:
|
||||||
|
for i, segment := range pathSegments {
|
||||||
|
isLastSegment := i == len(pathSegments)-1
|
||||||
|
|
||||||
|
for _, dir := range currentLevel.Directories {
|
||||||
|
if dir.Name == segment {
|
||||||
|
if isLastSegment {
|
||||||
|
return &TreeEntry{Path: segment, Mode: "40000", IsDir: true, Hash: dir.Hash}, nil
|
||||||
|
}
|
||||||
|
if dir.Content != nil {
|
||||||
|
currentLevel = dir.Content
|
||||||
|
continue nextSegment
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLastSegment {
|
||||||
|
for _, file := range currentLevel.Files {
|
||||||
|
if file.Name == segment {
|
||||||
|
return &TreeEntry{Path: segment, Mode: "100644", IsDir: false, Hash: file.Hash}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,289 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRefs(t *testing.T) {
|
||||||
|
info, err := GetInfoRefs("https://codeberg.org/dluvian/gitplaza.git")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Contains(t, info.Capabilities, "shallow")
|
||||||
|
require.Contains(t, info.Capabilities, "object-format=sha1")
|
||||||
|
require.Greater(t, len(info.Refs), 5)
|
||||||
|
require.Contains(t, info.Refs, "refs/heads/master")
|
||||||
|
require.Equal(t, "a04d0761564b0d23c5edbadf494ab4f1cc4656f4", info.Refs["refs/tags/v0.1.0"])
|
||||||
|
require.Equal(t, "refs/heads/master", info.Symrefs["HEAD"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOnlyTreeAtCurrentCommit(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
"https://codeberg.org/dluvian/gitplaza.git",
|
||||||
|
"https://github.com/fiatjaf/pyramid.git",
|
||||||
|
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, url := range urls {
|
||||||
|
t.Run(url, func(t *testing.T) {
|
||||||
|
tree, err := GetDirectoryTreeAt(url, "refs/heads/master", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, file := range tree.Files {
|
||||||
|
require.Nil(t, file.Content, "file %q should have nil content at %s", file.Name, url)
|
||||||
|
}
|
||||||
|
require.Greater(t, len(tree.Directories), 2, "at %s", url)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneRepositoryAtCurrentCommit(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
"https://codeberg.org/dluvian/gitplaza.git",
|
||||||
|
"https://github.com/fiatjaf/pyramid.git",
|
||||||
|
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, url := range urls {
|
||||||
|
t.Run(url, func(t *testing.T) {
|
||||||
|
commit, tree, err := ShallowCloneRepositoryAt(url, "refs/heads/master")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, len(tree.Files), 5, "at %s", url)
|
||||||
|
require.Greater(t, len(tree.Directories), 2, "at %s", url)
|
||||||
|
|
||||||
|
info, err := GetInfoRefs(url)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, info.Refs["refs/heads/master"], commit.Hash, "at %s", url)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSpecificObject(t *testing.T) {
|
||||||
|
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||||
|
hash := "0f9438a8fd68594cd663fb8dbd23c5f5139f5263" // shell.nix
|
||||||
|
|
||||||
|
blob, err := GetObject(url, hash)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, blob)
|
||||||
|
require.Equal(t, ObjectTypeBlob, blob.Type)
|
||||||
|
|
||||||
|
expected := "(builtins.getFlake\n (\"git+file://\" + toString ./.)).devShells.${builtins.currentSystem}.default\n"
|
||||||
|
require.Equal(t, expected, string(blob.Data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNonExistentCommit(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
"https://codeberg.org/dluvian/gitplaza.git",
|
||||||
|
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||||
|
"https://github.com/fiatjaf/nak.git",
|
||||||
|
}
|
||||||
|
|
||||||
|
commit := "1d4438a8fd68594cd663fb8dbd23c5f5139fabcd" // doesn't exist
|
||||||
|
|
||||||
|
for _, url := range urls {
|
||||||
|
t.Run(url, func(t *testing.T) {
|
||||||
|
_, err := GetDirectoryTreeAt(url, commit, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
var missingRef *MissingRef
|
||||||
|
require.ErrorAs(t, err, &missingRef)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchListOfCommits(t *testing.T) {
|
||||||
|
commits, err := FetchCommitsOnly(
|
||||||
|
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||||
|
"refs/heads/master",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, len(commits), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetch10PastCommits(t *testing.T) {
|
||||||
|
maxCommits := 10
|
||||||
|
commits, err := FetchCommitsOnly(
|
||||||
|
"https://github.com/fiatjaf/pyramid.git",
|
||||||
|
"57712756e37d7c60d1ac53e0f6b59e9ecad67c9a",
|
||||||
|
&maxCommits,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, commits, 10)
|
||||||
|
|
||||||
|
c := commits[1]
|
||||||
|
require.Equal(t, "49c1b48f5120bad4089535a190d2233c96188fa2", c.Hash)
|
||||||
|
require.Equal(t, "286786a6f1072a2ef5ae057fbb611858b8e88bc4", c.Tree)
|
||||||
|
require.Equal(t, []string{"1599e46c0ee6f460e25048880754868d4f9644fd"}, c.Parents)
|
||||||
|
require.Equal(t, "fiatjaf", c.Author.Name)
|
||||||
|
require.Equal(t, "fiatjaf@gmail.com", c.Author.Email)
|
||||||
|
require.Equal(t, int64(1767157644), c.Author.Timestamp)
|
||||||
|
require.Equal(t, "-0300", c.Author.Timezone)
|
||||||
|
require.Equal(t, "fiatjaf", c.Committer.Name)
|
||||||
|
require.Equal(t, "fiatjaf@gmail.com", c.Committer.Email)
|
||||||
|
require.Equal(t, int64(1767157724), c.Committer.Timestamp)
|
||||||
|
require.Equal(t, "-0300", c.Committer.Timezone)
|
||||||
|
require.Equal(t, "scheduled events.\n", c.Message)
|
||||||
|
|
||||||
|
expectedMsg5 := "turn off groups logic on QueryStore and PreventBroadcast when groups is turned off.\n\nthis was causing crashes that Golang's bizarre iter API showed as happening inside SortedMerge.\n"
|
||||||
|
require.Equal(t, expectedMsg5, commits[5].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSingleCommit(t *testing.T) {
|
||||||
|
url := "https://github.com/fiatjaf/pyramid.git"
|
||||||
|
commit, err := GetSingleCommit(url, "5e982dd1122a0bb1b0154c222ec4ba841f3820c6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "5e982dd1122a0bb1b0154c222ec4ba841f3820c6", commit.Hash)
|
||||||
|
require.Equal(t, "fiatjaf", commit.Author.Name)
|
||||||
|
require.Equal(t, "validate incoming git-related stuff.\n", commit.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDirectoryTreeWithDepthLimit(t *testing.T) {
|
||||||
|
url := "https://github.com/fiatjaf/pyramid.git"
|
||||||
|
|
||||||
|
fullTree, err := GetDirectoryTreeAt(url, "refs/heads/master", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
depth := 1
|
||||||
|
shallowTree, err := GetDirectoryTreeAt(url, "refs/heads/master", &depth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, len(fullTree.Directories), len(shallowTree.Directories))
|
||||||
|
|
||||||
|
for _, dir := range shallowTree.Directories {
|
||||||
|
require.NotNil(t, dir.Content, "directory %q content should not be nil at depth 1", dir.Name)
|
||||||
|
for _, file := range dir.Content.Files {
|
||||||
|
require.NotEmpty(t, file.Name)
|
||||||
|
require.Nil(t, file.Content, "file %q content should be nil", file.Name)
|
||||||
|
}
|
||||||
|
for _, subdir := range dir.Content.Directories {
|
||||||
|
require.NotEmpty(t, subdir.Name)
|
||||||
|
require.Nil(t, subdir.Content, "subdir %q content should be nil at depth 1", subdir.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, len(fullTree.Files), len(shallowTree.Files))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetObjectByPathExistingFile(t *testing.T) {
|
||||||
|
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||||
|
entry, err := GetObjectByPath(url, "refs/heads/master", "README.md")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, entry)
|
||||||
|
require.Equal(t, "README.md", entry.Path)
|
||||||
|
require.False(t, entry.IsDir)
|
||||||
|
require.Equal(t, "100644", entry.Mode)
|
||||||
|
require.NotEmpty(t, entry.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetObjectByPathExistingDirectory(t *testing.T) {
|
||||||
|
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||||
|
entry, err := GetObjectByPath(url, "refs/heads/master", "src")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, entry)
|
||||||
|
require.Equal(t, "src", entry.Path)
|
||||||
|
require.True(t, entry.IsDir)
|
||||||
|
require.Equal(t, "40000", entry.Mode)
|
||||||
|
require.NotEmpty(t, entry.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetObjectByPathNestedFile(t *testing.T) {
|
||||||
|
url := "https://github.com/fiatjaf/pyramid.git"
|
||||||
|
entry, err := GetObjectByPath(url, "d567c18cd5c144a58b0214216f454b3caf49d4ff", "grasp/grasp.templ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, entry)
|
||||||
|
require.Equal(t, "grasp.templ", entry.Path)
|
||||||
|
require.False(t, entry.IsDir)
|
||||||
|
require.Equal(t, "100644", entry.Mode)
|
||||||
|
require.Equal(t, "05bce14339ece5f48c670d0592faa8dece9e8957", entry.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetObjectByPathNonExistent(t *testing.T) {
|
||||||
|
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||||
|
entry, err := GetObjectByPath(url, "refs/heads/master", "whatever/something/x/y/z/non-existent-file.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitDiff(t *testing.T) {
|
||||||
|
url := "https://github.com/smallhelm/diff-lines.git"
|
||||||
|
diff, err := GetCommitDiff(url, "a73592653fe9d01f948ca3035e088e45f722eca7")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, diff)
|
||||||
|
require.Len(t, diff, 5)
|
||||||
|
|
||||||
|
byPath := make(map[string]DiffFile, len(diff))
|
||||||
|
for _, f := range diff {
|
||||||
|
byPath[f.Path] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
path string
|
||||||
|
status string
|
||||||
|
}{
|
||||||
|
{".travis.yml", "added"},
|
||||||
|
{".gitignore", "added"},
|
||||||
|
{"index.js", "added"},
|
||||||
|
{"tests.js", "added"},
|
||||||
|
{"package.json", "changed"},
|
||||||
|
} {
|
||||||
|
f, ok := byPath[tc.path]
|
||||||
|
require.True(t, ok, "missing diff file %q", tc.path)
|
||||||
|
require.Equal(t, tc.status, f.Status, "%s status", tc.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
gitignore, ok := byPath[".gitignore"]
|
||||||
|
require.True(t, ok, "missing .gitignore in diff")
|
||||||
|
require.Equal(t, "/node_modules\n", string(gitignore.Content))
|
||||||
|
|
||||||
|
pkg, ok := byPath["package.json"]
|
||||||
|
require.True(t, ok, "missing package.json in diff")
|
||||||
|
require.NotEmpty(t, pkg.Lines, "package.json should have diff lines")
|
||||||
|
|
||||||
|
normalizeLineStatus := func(status string) string {
|
||||||
|
if status == "same" {
|
||||||
|
return "changed"
|
||||||
|
}
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
|
lineByIndex := make(map[int]DiffLine, len(pkg.Lines))
|
||||||
|
for _, line := range pkg.Lines {
|
||||||
|
lineByIndex[line.Index] = line
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedLines := []struct {
|
||||||
|
Index int
|
||||||
|
Status string
|
||||||
|
Text string
|
||||||
|
}{
|
||||||
|
{Index: 22, Status: "added", Text: " },"},
|
||||||
|
{Index: 23, Status: "added", Text: " \"homepage\": \"https://github.com/smallhelm/diff-lines#readme\","},
|
||||||
|
{Index: 24, Status: "added", Text: " \"devDependencies\": {"},
|
||||||
|
{Index: 25, Status: "added", Text: " \"tape\": \"^4.6.0\""},
|
||||||
|
{Index: 27, Status: "added", Text: " \"dependencies\": {"},
|
||||||
|
{Index: 28, Status: "added", Text: " \"diff\": \"^2.2.3\""},
|
||||||
|
{Index: 29, Status: "changed", Text: " }"},
|
||||||
|
}
|
||||||
|
|
||||||
|
actualLines := make([]struct {
|
||||||
|
Index int
|
||||||
|
Status string
|
||||||
|
Text string
|
||||||
|
}, 0, len(expectedLines))
|
||||||
|
for _, expected := range expectedLines {
|
||||||
|
line, ok := lineByIndex[expected.Index]
|
||||||
|
require.True(t, ok, "missing package.json diff line %d", expected.Index)
|
||||||
|
actualLines = append(actualLines, struct {
|
||||||
|
Index int
|
||||||
|
Status string
|
||||||
|
Text string
|
||||||
|
}{
|
||||||
|
Index: line.Index,
|
||||||
|
Status: normalizeLineStatus(line.Status),
|
||||||
|
Text: line.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, expectedLines, actualLines)
|
||||||
|
}
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var NecessaryCapabilities = []string{
|
||||||
|
"multi_ack_detailed",
|
||||||
|
"side-band-64k",
|
||||||
|
}
|
||||||
|
|
||||||
|
var RequiredCapabilities = []string{
|
||||||
|
"shallow",
|
||||||
|
"object-format=sha1",
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultCapabilities = []string{
|
||||||
|
"ofs-delta",
|
||||||
|
"no-progress",
|
||||||
|
}
|
||||||
|
|
||||||
|
type MissingRef struct{}
|
||||||
|
|
||||||
|
func (e *MissingRef) Error() string { return "missing ref" }
|
||||||
|
|
||||||
|
type InvalidCommit struct {
|
||||||
|
Commit string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *InvalidCommit) Error() string {
|
||||||
|
return fmt.Sprintf("invalid commit '%s', must be 20 byte hex", e.Commit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchPackfile(url string, want string) (*PackfileResult, error) {
|
||||||
|
req, err := http.NewRequest("POST", url+"/git-upload-pack", strings.NewReader(want))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create git-upload-pack request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-git-upload-pack-request")
|
||||||
|
req.Header.Set("Accept", "application/x-git-upload-pack-result")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to call git-upload-pack: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("failed to call git-upload-pack: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read git-upload-pack response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty response")
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
for offset < len(data) {
|
||||||
|
prev := offset
|
||||||
|
if prev+1 >= len(data) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nlIdx := bytes.IndexByte(data[prev+1:], '\n')
|
||||||
|
if nlIdx == -1 {
|
||||||
|
if len(data) >= 32 && string(data[4:32]) == "ERR upload-pack: not our ref" {
|
||||||
|
return nil, &MissingRef{}
|
||||||
|
}
|
||||||
|
end := len(data)
|
||||||
|
if end > 63 {
|
||||||
|
end = 63
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected '%s'", string(data[:end]))
|
||||||
|
}
|
||||||
|
offset = prev + nlIdx + 1
|
||||||
|
if offset >= 3 && string(data[offset-3:offset]) == "NAK" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset++
|
||||||
|
|
||||||
|
var packfileData []byte
|
||||||
|
for offset < len(data) {
|
||||||
|
if offset+5 > len(data) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pktLen, err := strconv.ParseInt(string(data[offset:offset+4]), 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
length := int(pktLen)
|
||||||
|
if length == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if offset+length > len(data) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if data[offset+4] == 2 {
|
||||||
|
// progress message, ignore
|
||||||
|
} else if data[offset+4] == 1 {
|
||||||
|
packfileData = append(packfileData, data[offset+5:offset+length]...)
|
||||||
|
}
|
||||||
|
offset += length
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packfileData) == 0 {
|
||||||
|
return nil, &MissingRef{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParsePackfile(packfileData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateWantRequest(commitSha string, capabilities []string, deepen *int, filter string) (string, error) {
|
||||||
|
if len(commitSha) != 40 {
|
||||||
|
return "", &InvalidCommit{Commit: commitSha}
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
|
||||||
|
wantLine := fmt.Sprintf("want %s %s agent=nsa/1.0.0\n", commitSha, strings.Join(capabilities, " "))
|
||||||
|
buf.WriteString(pktEncode(wantLine))
|
||||||
|
|
||||||
|
if deepen != nil {
|
||||||
|
deepenLine := fmt.Sprintf("deepen %d\n", *deepen)
|
||||||
|
buf.WriteString(pktEncode(deepenLine))
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter != "" {
|
||||||
|
filterLine := fmt.Sprintf("filter %s\n", filter)
|
||||||
|
buf.WriteString(pktEncode(filterLine))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString("0000")
|
||||||
|
buf.WriteString(pktEncode("done\n"))
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pktEncode(data string) string {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "0000"
|
||||||
|
}
|
||||||
|
length := len(data) + 4
|
||||||
|
return fmt.Sprintf("%04x%s", length, data)
|
||||||
|
}
|
||||||
@@ -0,0 +1,307 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/zlib"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ObjectTypeCommit = 1
|
||||||
|
ObjectTypeTree = 2
|
||||||
|
ObjectTypeBlob = 3
|
||||||
|
ObjectTypeTag = 4
|
||||||
|
ObjectTypeOfsDelta = 6
|
||||||
|
ObjectTypeRefDelta = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParsedObject struct {
|
||||||
|
Type int
|
||||||
|
Size int
|
||||||
|
Data []byte
|
||||||
|
Offset int
|
||||||
|
Hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PackfileResult struct {
|
||||||
|
Version int
|
||||||
|
Count int
|
||||||
|
Objects map[string]*ParsedObject
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParsePackfile(data []byte) (*PackfileResult, error) {
|
||||||
|
if len(data) < 12 {
|
||||||
|
return nil, fmt.Errorf("packfile too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
header := string(data[0:4])
|
||||||
|
if header != "PACK" {
|
||||||
|
return nil, fmt.Errorf("invalid packfile header: %s", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
version := int(binary.BigEndian.Uint32(data[4:8]))
|
||||||
|
if version != 2 {
|
||||||
|
return nil, fmt.Errorf("unsupported packfile version: %d", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := int(binary.BigEndian.Uint32(data[8:12]))
|
||||||
|
|
||||||
|
objects := make(map[string]*ParsedObject)
|
||||||
|
pos := 12
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
obj, newPos, err := parsePackObject(data, pos, objects)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing object %d/%d: %w", i+1, count, err)
|
||||||
|
}
|
||||||
|
objects[obj.Hash] = obj
|
||||||
|
pos = newPos
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PackfileResult{
|
||||||
|
Version: version,
|
||||||
|
Count: count,
|
||||||
|
Objects: objects,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePackObject(data []byte, startPos int, objects map[string]*ParsedObject) (*ParsedObject, int, error) {
|
||||||
|
pos := startPos
|
||||||
|
offset := startPos
|
||||||
|
|
||||||
|
b := data[pos]
|
||||||
|
pos++
|
||||||
|
objType := int((b >> 4) & 0x07)
|
||||||
|
size := int(b & 0x0f)
|
||||||
|
shift := 4
|
||||||
|
|
||||||
|
for b&0x80 != 0 {
|
||||||
|
b = data[pos]
|
||||||
|
pos++
|
||||||
|
size |= int(b&0x7f) << shift
|
||||||
|
shift += 7
|
||||||
|
}
|
||||||
|
|
||||||
|
var objData []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch objType {
|
||||||
|
case ObjectTypeOfsDelta:
|
||||||
|
var actualType int
|
||||||
|
objData, pos, actualType, err = parseOfsDelta(data, pos, offset, objects)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
objType = actualType
|
||||||
|
case ObjectTypeRefDelta:
|
||||||
|
var actualType int
|
||||||
|
objData, pos, actualType, err = parseRefDelta(data, pos, objects)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
objType = actualType
|
||||||
|
case ObjectTypeCommit, ObjectTypeTree, ObjectTypeBlob, ObjectTypeTag:
|
||||||
|
objData, pos, err = zlibDecompress(data, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, 0, fmt.Errorf("unknown object type: %d", objType)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash, err := computeObjectHash(objType, objData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ParsedObject{
|
||||||
|
Type: objType,
|
||||||
|
Size: size,
|
||||||
|
Data: objData,
|
||||||
|
Offset: offset,
|
||||||
|
Hash: hash,
|
||||||
|
}, pos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOfsDelta(data []byte, pos int, currentOffset int, objects map[string]*ParsedObject) ([]byte, int, int, error) {
|
||||||
|
b := data[pos]
|
||||||
|
pos++
|
||||||
|
offset := int(b & 0x7f)
|
||||||
|
|
||||||
|
for b&0x80 != 0 {
|
||||||
|
offset++
|
||||||
|
offset <<= 7
|
||||||
|
b = data[pos]
|
||||||
|
pos++
|
||||||
|
offset += int(b & 0x7f)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseOffset := currentOffset - offset
|
||||||
|
baseObject, _, err := parsePackObject(data, baseOffset, objects)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, fmt.Errorf("failed to parse base object at offset %d: %w", baseOffset, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delta, newPos, err := zlibDecompress(data, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fullObj, err := applyDelta(delta, baseObject.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullObj, newPos, baseObject.Type, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRefDelta(data []byte, pos int, objects map[string]*ParsedObject) ([]byte, int, int, error) {
|
||||||
|
baseName := hex.EncodeToString(data[pos : pos+20])
|
||||||
|
pos += 20
|
||||||
|
|
||||||
|
delta, newPos, err := zlibDecompress(data, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
baseObject, ok := objects[baseName]
|
||||||
|
if !ok {
|
||||||
|
return nil, 0, 0, fmt.Errorf("base object not found with name %s", baseName)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullObj, err := applyDelta(delta, baseObject.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullObj, newPos, baseObject.Type, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeObjectHash(objType int, data []byte) (string, error) {
|
||||||
|
var typeStr string
|
||||||
|
switch objType {
|
||||||
|
case ObjectTypeCommit:
|
||||||
|
typeStr = "commit"
|
||||||
|
case ObjectTypeTree:
|
||||||
|
typeStr = "tree"
|
||||||
|
case ObjectTypeBlob:
|
||||||
|
typeStr = "blob"
|
||||||
|
case ObjectTypeTag:
|
||||||
|
typeStr = "tag"
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unknown type when computing object hash: %d", objType)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := fmt.Sprintf("%s %d\x00", typeStr, len(data))
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(header))
|
||||||
|
h.Write(data)
|
||||||
|
return hex.EncodeToString(h.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDelta(delta []byte, base []byte) ([]byte, error) {
|
||||||
|
pos := 0
|
||||||
|
|
||||||
|
_, bytesRead := readVariableInt(delta, pos)
|
||||||
|
pos += bytesRead
|
||||||
|
|
||||||
|
resultSize, bytesRead := readVariableInt(delta, pos)
|
||||||
|
pos += bytesRead
|
||||||
|
|
||||||
|
result := make([]byte, resultSize)
|
||||||
|
resultOffset := 0
|
||||||
|
|
||||||
|
for pos < len(delta) {
|
||||||
|
cmd := delta[pos]
|
||||||
|
pos++
|
||||||
|
|
||||||
|
if cmd&0x80 != 0 {
|
||||||
|
var copyOffset, copySize int
|
||||||
|
|
||||||
|
if cmd&0x01 != 0 {
|
||||||
|
copyOffset = int(delta[pos])
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if cmd&0x02 != 0 {
|
||||||
|
copyOffset |= int(delta[pos]) << 8
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if cmd&0x04 != 0 {
|
||||||
|
copyOffset |= int(delta[pos]) << 16
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if cmd&0x08 != 0 {
|
||||||
|
copyOffset |= int(delta[pos]) << 24
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd&0x10 != 0 {
|
||||||
|
copySize = int(delta[pos])
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if cmd&0x20 != 0 {
|
||||||
|
copySize |= int(delta[pos]) << 8
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if cmd&0x40 != 0 {
|
||||||
|
copySize |= int(delta[pos]) << 16
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
|
||||||
|
if copySize == 0 {
|
||||||
|
copySize = 0x10000
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(result[resultOffset:], base[copyOffset:copyOffset+copySize])
|
||||||
|
resultOffset += copySize
|
||||||
|
} else if cmd > 0 {
|
||||||
|
copy(result[resultOffset:], delta[pos:pos+int(cmd)])
|
||||||
|
pos += int(cmd)
|
||||||
|
resultOffset += int(cmd)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid delta command")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func zlibDecompress(data []byte, pos int) ([]byte, int, error) {
|
||||||
|
br := bytes.NewReader(data[pos:])
|
||||||
|
r, err := zlib.NewReader(br)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("zlib init error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decompressed, err := io.ReadAll(r)
|
||||||
|
r.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("zlib decompress error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPos := len(data) - br.Len()
|
||||||
|
return decompressed, newPos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readVariableInt(data []byte, pos int) (int, int) {
|
||||||
|
value := 0
|
||||||
|
shift := 0
|
||||||
|
bytesRead := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
b := data[pos]
|
||||||
|
pos++
|
||||||
|
bytesRead++
|
||||||
|
value |= int(b&0x7f) << shift
|
||||||
|
shift += 7
|
||||||
|
if b&0x80 == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, bytesRead
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InfoRefsUploadPackResponse struct {
|
||||||
|
Refs map[string]string
|
||||||
|
Capabilities []string
|
||||||
|
Symrefs map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
var capabilitiesCache sync.Map
|
||||||
|
|
||||||
|
func GetCapabilities(url string, existingInfo *InfoRefsUploadPackResponse) ([]string, error) {
|
||||||
|
if existingInfo != nil {
|
||||||
|
capabilitiesCache.Store(url, existingInfo.Capabilities)
|
||||||
|
return existingInfo.Capabilities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cached, ok := capabilitiesCache.Load(url); ok {
|
||||||
|
return cached.([]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := GetInfoRefs(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
capabilitiesCache.Store(url, info.Capabilities)
|
||||||
|
return info.Capabilities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetInfoRefs(url string) (*InfoRefsUploadPackResponse, error) {
|
||||||
|
resp, err := http.Get(url + "/info/refs?service=git-upload-pack")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch info/refs: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read info/refs response: %w", err)
|
||||||
|
}
|
||||||
|
response := string(body)
|
||||||
|
|
||||||
|
result := &InfoRefsUploadPackResponse{
|
||||||
|
Refs: make(map[string]string),
|
||||||
|
Symrefs: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(response, "\n")
|
||||||
|
firstRef := true
|
||||||
|
for _, line := range lines {
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "0000") {
|
||||||
|
line = line[4:]
|
||||||
|
}
|
||||||
|
if len(line) < 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
length, err := strconv.ParseInt(line[:4], 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endIdx := int(length)
|
||||||
|
if endIdx > len(line) {
|
||||||
|
endIdx = len(line)
|
||||||
|
}
|
||||||
|
if endIdx <= 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := line[4:endIdx]
|
||||||
|
|
||||||
|
if firstRef && strings.HasPrefix(content, "# service=") {
|
||||||
|
firstRef = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(content, " ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(content, " ", 2)
|
||||||
|
hash := parts[0]
|
||||||
|
refAndCaps := parts[1]
|
||||||
|
|
||||||
|
if strings.Contains(refAndCaps, "\x00") {
|
||||||
|
nulParts := strings.SplitN(refAndCaps, "\x00", 2)
|
||||||
|
ref := strings.TrimSpace(nulParts[0])
|
||||||
|
result.Refs[ref] = hash
|
||||||
|
|
||||||
|
caps := strings.Fields(nulParts[1])
|
||||||
|
result.Capabilities = caps
|
||||||
|
|
||||||
|
for _, cap := range caps {
|
||||||
|
if strings.HasPrefix(cap, "symref=") {
|
||||||
|
symrefData := cap[7:]
|
||||||
|
colonIdx := strings.Index(symrefData, ":")
|
||||||
|
if colonIdx != -1 {
|
||||||
|
result.Symrefs[symrefData[:colonIdx]] = symrefData[colonIdx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.Refs[strings.TrimSpace(refAndCaps)] = hash
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package gitnaturalapi
|
||||||
|
|
||||||
|
import "encoding/hex"
|
||||||
|
|
||||||
|
type TreeEntry struct {
|
||||||
|
Path string
|
||||||
|
Mode string
|
||||||
|
IsDir bool
|
||||||
|
Hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TreeFile struct {
|
||||||
|
Name string
|
||||||
|
Hash string
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type TreeDirectory struct {
|
||||||
|
Name string
|
||||||
|
Hash string
|
||||||
|
Content *Tree
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tree struct {
|
||||||
|
Directories []TreeDirectory
|
||||||
|
Files []TreeFile
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadTree(obj *ParsedObject, objects map[string]*ParsedObject, depth *int) *Tree {
|
||||||
|
directories := make([]TreeDirectory, 0)
|
||||||
|
files := make([]TreeFile, 0)
|
||||||
|
entries := ParseTree(obj.Data)
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
child := objects[entry.Hash]
|
||||||
|
|
||||||
|
if entry.IsDir {
|
||||||
|
var content *Tree
|
||||||
|
if child != nil && (depth == nil || *depth > 0) {
|
||||||
|
var newDepth *int
|
||||||
|
if depth != nil {
|
||||||
|
d := *depth - 1
|
||||||
|
newDepth = &d
|
||||||
|
}
|
||||||
|
content = LoadTree(child, objects, newDepth)
|
||||||
|
}
|
||||||
|
directories = append(directories, TreeDirectory{
|
||||||
|
Name: entry.Path,
|
||||||
|
Hash: entry.Hash,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
var content []byte
|
||||||
|
if child != nil {
|
||||||
|
content = child.Data
|
||||||
|
}
|
||||||
|
files = append(files, TreeFile{
|
||||||
|
Name: entry.Path,
|
||||||
|
Hash: entry.Hash,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Tree{Directories: directories, Files: files}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseTree(treeData []byte) []TreeEntry {
|
||||||
|
entries := make([]TreeEntry, 0)
|
||||||
|
offset := 0
|
||||||
|
|
||||||
|
for offset < len(treeData) {
|
||||||
|
modeEnd := offset
|
||||||
|
for treeData[modeEnd] != 0x20 {
|
||||||
|
modeEnd++
|
||||||
|
}
|
||||||
|
mode := string(treeData[offset:modeEnd])
|
||||||
|
offset = modeEnd + 1
|
||||||
|
|
||||||
|
filenameEnd := offset
|
||||||
|
for treeData[filenameEnd] != 0x00 {
|
||||||
|
filenameEnd++
|
||||||
|
}
|
||||||
|
path := string(treeData[offset:filenameEnd])
|
||||||
|
offset = filenameEnd + 1
|
||||||
|
|
||||||
|
hash := hex.EncodeToString(treeData[offset : offset+20])
|
||||||
|
offset += 20
|
||||||
|
|
||||||
|
isDir := mode == "40000" || mode == "040000"
|
||||||
|
|
||||||
|
entries = append(entries, TreeEntry{
|
||||||
|
Mode: mode,
|
||||||
|
Path: path,
|
||||||
|
Hash: hash,
|
||||||
|
IsDir: isDir,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entries
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package grasp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"fiatjaf.com/nostr/nip19"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsGraspURL(u string) bool {
|
||||||
|
parsed, err := url.Parse(u)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Count(parsed.Path, "/") != 2 || len(parsed.Path) < 65 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix, _, err := nip19.Decode(parsed.Path[1:64]); err != nil || prefix != "npub" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
+1
-1
@@ -112,7 +112,7 @@ func NewBunker(
|
|||||||
onAuth func(string),
|
onAuth func(string),
|
||||||
) *BunkerClient {
|
) *BunkerClient {
|
||||||
if pool == nil {
|
if pool == nil {
|
||||||
pool = nostr.NewPool(nostr.PoolOptions{})
|
pool = nostr.NewPool()
|
||||||
}
|
}
|
||||||
|
|
||||||
clientPublicKey := nostr.GetPublicKey(clientSecretKey)
|
clientPublicKey := nostr.GetPublicKey(clientSecretKey)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func NewBunkerFromNostrConnect(
|
|||||||
pool *nostr.Pool,
|
pool *nostr.Pool,
|
||||||
) (*BunkerClient, error) {
|
) (*BunkerClient, error) {
|
||||||
if pool == nil {
|
if pool == nil {
|
||||||
pool = nostr.NewPool(nostr.PoolOptions{})
|
pool = nostr.NewPool()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(relayURLs) == 0 {
|
if len(relayURLs) == 0 {
|
||||||
|
|||||||
+12
-8
@@ -11,20 +11,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NormalizeIdentifier(name string) string {
|
func NormalizeIdentifier(name string) string {
|
||||||
name = strings.TrimSpace(strings.ToLower(name))
|
|
||||||
res, _, _ := transform.Bytes(norm.NFKC, []byte(name))
|
res, _, _ := transform.Bytes(norm.NFKC, []byte(name))
|
||||||
runes := []rune(string(res))
|
runes := []rune(strings.ToLower(string(res)))
|
||||||
|
|
||||||
b := make([]rune, len(runes))
|
words := make([]string, 0, 3)
|
||||||
for i, letter := range runes {
|
word := make([]rune, 0, 12)
|
||||||
|
for _, letter := range runes {
|
||||||
if unicode.IsLetter(letter) || unicode.IsNumber(letter) {
|
if unicode.IsLetter(letter) || unicode.IsNumber(letter) {
|
||||||
b[i] = letter
|
word = append(word, letter)
|
||||||
} else {
|
} else if len(word) > 0 {
|
||||||
b[i] = '-'
|
words = append(words, string(word))
|
||||||
|
word = make([]rune, 0, 12)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(word) > 0 {
|
||||||
|
words = append(words, string(word))
|
||||||
|
}
|
||||||
|
|
||||||
return string(b)
|
return strings.Join(words, "-")
|
||||||
}
|
}
|
||||||
|
|
||||||
func ArticleAsHTML(content string) string {
|
func ArticleAsHTML(content string) string {
|
||||||
|
|||||||
+1
-1
@@ -13,7 +13,7 @@ func TestNormalization(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{" hello ", "hello"},
|
{" hello ", "hello"},
|
||||||
{"Goodbye", "goodbye"},
|
{"Goodbye", "goodbye"},
|
||||||
{"the long and winding road / that leads to your door", "the-long-and-winding-road---that-leads-to-your-door"},
|
{"the long and winding road / that leads to your door", "the-long-and-winding-road-that-leads-to-your-door"},
|
||||||
{"it's 平仮名", "it-s-平仮名"},
|
{"it's 平仮名", "it-s-平仮名"},
|
||||||
} {
|
} {
|
||||||
if norm := NormalizeIdentifier(vector.before); norm != vector.after {
|
if norm := NormalizeIdentifier(vector.before); norm != vector.after {
|
||||||
|
|||||||
+40
-8
@@ -94,6 +94,11 @@ meltworked:
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mark tokens as reserved before attempting melt
|
||||||
|
for _, i := range chosen.tokenIndexes {
|
||||||
|
w.Tokens[i].reserved = true
|
||||||
|
}
|
||||||
|
|
||||||
// request from mint to _melt_ into paying the invoice
|
// request from mint to _melt_ into paying the invoice
|
||||||
delay := 200 * time.Millisecond
|
delay := 200 * time.Millisecond
|
||||||
// this request will block until the invoice is paid or it fails
|
// this request will block until the invoice is paid or it fails
|
||||||
@@ -103,17 +108,44 @@ meltworked:
|
|||||||
Inputs: chosen.proofs,
|
Inputs: chosen.proofs,
|
||||||
Outputs: preChange.bm,
|
Outputs: preChange.bm,
|
||||||
})
|
})
|
||||||
inspectmeltstatusresponse:
|
for {
|
||||||
if err != nil || meltStatus.State == nut05.Unpaid {
|
if err != nil || meltStatus.State == nut05.Unpaid {
|
||||||
return "", fmt.Errorf("error melting token: %w", err)
|
// unreserve tokens to available state on failure
|
||||||
} else if meltStatus.State == nut05.Unknown {
|
for _, i := range chosen.tokenIndexes {
|
||||||
return "", fmt.Errorf("we don't know what happened with the melt at %s: %v", chosen.mint, meltStatus)
|
w.Tokens[i].reserved = false
|
||||||
} else if meltStatus.State == nut05.Pending {
|
}
|
||||||
for {
|
return "", fmt.Errorf("error melting token: %w", err)
|
||||||
|
} else if meltStatus.State == nut05.Unknown {
|
||||||
|
// unreserve tokens to available state on failure
|
||||||
|
for _, i := range chosen.tokenIndexes {
|
||||||
|
w.Tokens[i].reserved = false
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("we don't know what happened with the melt at %s: %v", chosen.mint, meltStatus)
|
||||||
|
} else if meltStatus.State == nut05.Pending {
|
||||||
time.Sleep(delay)
|
time.Sleep(delay)
|
||||||
delay *= 2
|
delay *= 2
|
||||||
meltStatus, err = client.GetMeltQuoteState(ctx, chosen.mint, meltStatus.Quote)
|
meltStatus, err = client.GetMeltQuoteState(ctx, chosen.mint, meltStatus.Quote)
|
||||||
goto inspectmeltstatusresponse
|
if err != nil {
|
||||||
|
// unreserve tokens to available state on failure
|
||||||
|
for _, i := range chosen.tokenIndexes {
|
||||||
|
w.Tokens[i].reserved = false
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("error checking melt status: %w", err)
|
||||||
|
}
|
||||||
|
if meltStatus.State == nut05.Unpaid || meltStatus.State == nut05.Unknown {
|
||||||
|
// unreserve tokens to available state on failure
|
||||||
|
for _, i := range chosen.tokenIndexes {
|
||||||
|
w.Tokens[i].reserved = false
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("melt failed with state %v", meltStatus.State)
|
||||||
|
} else if meltStatus.State == nut05.Paid {
|
||||||
|
// payment successful
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// continue looping for pending state
|
||||||
|
continue
|
||||||
|
} else if meltStatus.State == nut05.Paid {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -153,6 +153,9 @@ func (w *Wallet) getProofsForSending(
|
|||||||
) (chosenTokens, uint64, error) {
|
) (chosenTokens, uint64, error) {
|
||||||
byMint := make(map[string]chosenTokens)
|
byMint := make(map[string]chosenTokens)
|
||||||
for t, token := range w.Tokens {
|
for t, token := range w.Tokens {
|
||||||
|
if token.reserved {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if fromMint != "" && token.Mint != fromMint {
|
if fromMint != "" && token.Mint != fromMint {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type Token struct {
|
|||||||
Proofs cashu.Proofs `json:"proofs"`
|
Proofs cashu.Proofs `json:"proofs"`
|
||||||
Deleted []nostr.ID `json:"del,omitempty"`
|
Deleted []nostr.ID `json:"del,omitempty"`
|
||||||
|
|
||||||
|
reserved bool
|
||||||
mintedAt nostr.Timestamp
|
mintedAt nostr.Timestamp
|
||||||
event *nostr.Event
|
event *nostr.Event
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -249,6 +249,10 @@ func (w *Wallet) removeDeletedToken(eventId nostr.ID) {
|
|||||||
func (w *Wallet) Balance() uint64 {
|
func (w *Wallet) Balance() uint64 {
|
||||||
var sum uint64
|
var sum uint64
|
||||||
for _, token := range w.Tokens {
|
for _, token := range w.Tokens {
|
||||||
|
if token.reserved {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
sum += token.Proofs.Amount()
|
sum += token.Proofs.Amount()
|
||||||
}
|
}
|
||||||
return sum
|
return sum
|
||||||
|
|||||||
@@ -66,13 +66,13 @@ func (bw *BoundWriter) WriteTimestamp(w *bytes.Buffer, timestamp nostr.Timestamp
|
|||||||
bw.lastTimestampOut = timestamp
|
bw.lastTimestampOut = timestamp
|
||||||
|
|
||||||
// add 1 to prevent zeroes from being read as infinites
|
// add 1 to prevent zeroes from being read as infinites
|
||||||
WriteVarInt(w, int(delta+1))
|
WriteVarInt(w, uint64(delta)+1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bw *BoundWriter) WriteBound(w *bytes.Buffer, bound Bound) {
|
func (bw *BoundWriter) WriteBound(w *bytes.Buffer, bound Bound) {
|
||||||
bw.WriteTimestamp(w, bound.Timestamp)
|
bw.WriteTimestamp(w, bound.Timestamp)
|
||||||
WriteVarInt(w, len(bound.IDPrefix))
|
WriteVarInt(w, uint64(len(bound.IDPrefix)))
|
||||||
w.Write(bound.IDPrefix)
|
w.Write(bound.IDPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,33 +111,25 @@ func ReadVarInt(reader *bytes.Reader) (int, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteVarInt(w *bytes.Buffer, n int) {
|
func WriteVarInt(w *bytes.Buffer, n uint64) {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
w.WriteByte(0)
|
w.WriteByte(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Write(EncodeVarInt(n))
|
var buf [10]byte
|
||||||
}
|
idx := 9
|
||||||
|
|
||||||
func EncodeVarInt(n int) []byte {
|
|
||||||
if n == 0 {
|
|
||||||
return []byte{0}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]byte, 8)
|
|
||||||
idx := 7
|
|
||||||
|
|
||||||
for n != 0 {
|
for n != 0 {
|
||||||
result[idx] = byte(n & 0x7F)
|
buf[idx] = byte(n & 0x7F)
|
||||||
n >>= 7
|
n >>= 7
|
||||||
idx--
|
idx--
|
||||||
}
|
}
|
||||||
|
|
||||||
result = result[idx+1:]
|
result := buf[idx+1:]
|
||||||
for i := 0; i < len(result)-1; i++ {
|
for i := 0; i < len(result)-1; i++ {
|
||||||
result[i] |= 0x80
|
result[i] |= 0x80
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
w.Write(result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) {
|
|||||||
finishSkip()
|
finishSkip()
|
||||||
|
|
||||||
responseIds := make([]byte, 0, 32*100)
|
responseIds := make([]byte, 0, 32*100)
|
||||||
responses := 0
|
var responses uint64 = 0
|
||||||
|
|
||||||
endBound := currBound
|
endBound := currBound
|
||||||
|
|
||||||
@@ -284,7 +284,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte
|
|||||||
// we just send the full ids here
|
// we just send the full ids here
|
||||||
n.WriteBound(output, upperBound)
|
n.WriteBound(output, upperBound)
|
||||||
output.WriteByte(byte(IdListMode))
|
output.WriteByte(byte(IdListMode))
|
||||||
WriteVarInt(output, numElems)
|
WriteVarInt(output, uint64(numElems))
|
||||||
|
|
||||||
for _, item := range n.storage.Range(lower, upper) {
|
for _, item := range n.storage.Range(lower, upper) {
|
||||||
output.Write(item.ID[:])
|
output.Write(item.ID[:])
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
@@ -41,8 +42,8 @@ func (acc *Accumulator) AddBytes(other []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acc *Accumulator) GetFingerprint(n int) [negentropy.FingerprintSize]byte {
|
func (acc *Accumulator) GetFingerprint(n int) [negentropy.FingerprintSize]byte {
|
||||||
input := acc.Buf[:32]
|
input := bytes.NewBuffer(acc.Buf[:32])
|
||||||
input = append(input, negentropy.EncodeVarInt(n)...)
|
negentropy.WriteVarInt(input, uint64(n))
|
||||||
hash := sha256.Sum256(input)
|
hash := sha256.Sum256(input.Bytes())
|
||||||
return [negentropy.FingerprintSize]byte(hash[:negentropy.FingerprintSize])
|
return [negentropy.FingerprintSize]byte(hash[:negentropy.FingerprintSize])
|
||||||
}
|
}
|
||||||
|
|||||||
+16
-6
@@ -128,17 +128,27 @@ func NegentropySync(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(done)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
errch <- nil
|
select {
|
||||||
|
case errch <- nil:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = <-errch
|
select {
|
||||||
if err != nil {
|
case err = <-errch:
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SyncEventsFromIDs(ctx context.Context, dir Direction) {
|
func SyncEventsFromIDs(ctx context.Context, dir Direction) {
|
||||||
|
|||||||
@@ -32,6 +32,24 @@ func DecodeRequest(req Request) (MethodParams, error) {
|
|||||||
return BanPubKey{pk, reason}, nil
|
return BanPubKey{pk, reason}, nil
|
||||||
case "listbannedpubkeys":
|
case "listbannedpubkeys":
|
||||||
return ListBannedPubKeys{}, nil
|
return ListBannedPubKeys{}, nil
|
||||||
|
case "unbanpubkey":
|
||||||
|
if len(req.Params) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
pkh, ok := req.Params[0].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("missing pubkey param for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
pk, err := nostr.PubKeyFromHex(pkh)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid pubkey param for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason string
|
||||||
|
if len(req.Params) >= 2 {
|
||||||
|
reason, _ = req.Params[1].(string)
|
||||||
|
}
|
||||||
|
return UnbanPubKey{pk, reason}, nil
|
||||||
case "allowpubkey":
|
case "allowpubkey":
|
||||||
if len(req.Params) == 0 {
|
if len(req.Params) == 0 {
|
||||||
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||||
@@ -52,6 +70,24 @@ func DecodeRequest(req Request) (MethodParams, error) {
|
|||||||
return AllowPubKey{pk, reason}, nil
|
return AllowPubKey{pk, reason}, nil
|
||||||
case "listallowedpubkeys":
|
case "listallowedpubkeys":
|
||||||
return ListAllowedPubKeys{}, nil
|
return ListAllowedPubKeys{}, nil
|
||||||
|
case "unallowpubkey":
|
||||||
|
if len(req.Params) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
pkh, ok := req.Params[0].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("missing pubkey param for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
pk, err := nostr.PubKeyFromHex(pkh)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid pubkey param for '%s'", req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason string
|
||||||
|
if len(req.Params) >= 2 {
|
||||||
|
reason, _ = req.Params[1].(string)
|
||||||
|
}
|
||||||
|
return UnallowPubKey{pk, reason}, nil
|
||||||
case "listeventsneedingmoderation":
|
case "listeventsneedingmoderation":
|
||||||
return ListEventsNeedingModeration{}, nil
|
return ListEventsNeedingModeration{}, nil
|
||||||
case "allowevent":
|
case "allowevent":
|
||||||
@@ -219,8 +255,10 @@ var (
|
|||||||
_ MethodParams = (*SupportedMethods)(nil)
|
_ MethodParams = (*SupportedMethods)(nil)
|
||||||
_ MethodParams = (*BanPubKey)(nil)
|
_ MethodParams = (*BanPubKey)(nil)
|
||||||
_ MethodParams = (*ListBannedPubKeys)(nil)
|
_ MethodParams = (*ListBannedPubKeys)(nil)
|
||||||
|
_ MethodParams = (*UnbanPubKey)(nil)
|
||||||
_ MethodParams = (*AllowPubKey)(nil)
|
_ MethodParams = (*AllowPubKey)(nil)
|
||||||
_ MethodParams = (*ListAllowedPubKeys)(nil)
|
_ MethodParams = (*ListAllowedPubKeys)(nil)
|
||||||
|
_ MethodParams = (*UnallowPubKey)(nil)
|
||||||
_ MethodParams = (*ListEventsNeedingModeration)(nil)
|
_ MethodParams = (*ListEventsNeedingModeration)(nil)
|
||||||
_ MethodParams = (*AllowEvent)(nil)
|
_ MethodParams = (*AllowEvent)(nil)
|
||||||
_ MethodParams = (*BanEvent)(nil)
|
_ MethodParams = (*BanEvent)(nil)
|
||||||
@@ -256,6 +294,13 @@ type ListBannedPubKeys struct{}
|
|||||||
|
|
||||||
func (ListBannedPubKeys) MethodName() string { return "listbannedpubkeys" }
|
func (ListBannedPubKeys) MethodName() string { return "listbannedpubkeys" }
|
||||||
|
|
||||||
|
type UnbanPubKey struct {
|
||||||
|
PubKey nostr.PubKey
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnbanPubKey) MethodName() string { return "unbanpubkey" }
|
||||||
|
|
||||||
type AllowPubKey struct {
|
type AllowPubKey struct {
|
||||||
PubKey nostr.PubKey
|
PubKey nostr.PubKey
|
||||||
Reason string
|
Reason string
|
||||||
@@ -267,6 +312,13 @@ type ListAllowedPubKeys struct{}
|
|||||||
|
|
||||||
func (ListAllowedPubKeys) MethodName() string { return "listallowedpubkeys" }
|
func (ListAllowedPubKeys) MethodName() string { return "listallowedpubkeys" }
|
||||||
|
|
||||||
|
type UnallowPubKey struct {
|
||||||
|
PubKey nostr.PubKey
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnallowPubKey) MethodName() string { return "unallowpubkey" }
|
||||||
|
|
||||||
type ListEventsNeedingModeration struct{}
|
type ListEventsNeedingModeration struct{}
|
||||||
|
|
||||||
func (ListEventsNeedingModeration) MethodName() string { return "listeventsneedingmoderation" }
|
func (ListEventsNeedingModeration) MethodName() string { return "listeventsneedingmoderation" }
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func (c *Client) httpCall(
|
|||||||
}
|
}
|
||||||
if resp.Header.StatusCode() >= 300 {
|
if resp.Header.StatusCode() >= 300 {
|
||||||
reason := resp.Header.Peek("X-Reason")
|
reason := resp.Header.Peek("X-Reason")
|
||||||
|
if len(reason) == 0 {
|
||||||
|
reason = resp.Body()
|
||||||
|
if len(reason) > 200 {
|
||||||
|
reason = append(reason[0:199], []byte("…")...)
|
||||||
|
}
|
||||||
|
}
|
||||||
return fmt.Errorf("%s returned an error (%d): %s", url, resp.StatusCode(), string(reason))
|
return fmt.Errorf("%s returned an error (%d): %s", url, resp.StatusCode(), string(reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+11
-3
@@ -9,6 +9,7 @@ func GetExtension(mimetype string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hardcode some common cases (abd jbiwb oribkenatuc cases kuje ,ogg/.oga or .mov/.moov)
|
||||||
switch mimetype {
|
switch mimetype {
|
||||||
case "image/jpeg":
|
case "image/jpeg":
|
||||||
return ".jpg"
|
return ".jpg"
|
||||||
@@ -22,13 +23,20 @@ func GetExtension(mimetype string) string {
|
|||||||
return ".mp4"
|
return ".mp4"
|
||||||
case "application/vnd.android.package-archive":
|
case "application/vnd.android.package-archive":
|
||||||
return ".apk"
|
return ".apk"
|
||||||
|
case "video/quicktime":
|
||||||
|
return ".mov"
|
||||||
|
case "application/vnd.sqlite3":
|
||||||
|
return "sqlite3"
|
||||||
|
case "text/markdown":
|
||||||
|
return "md"
|
||||||
|
case "audio/midi":
|
||||||
|
return "midi"
|
||||||
|
case "audio/x-aiff":
|
||||||
|
return "aiff"
|
||||||
}
|
}
|
||||||
|
|
||||||
exts, _ := mime.ExtensionsByType(mimetype)
|
exts, _ := mime.ExtensionsByType(mimetype)
|
||||||
if len(exts) > 0 {
|
if len(exts) > 0 {
|
||||||
if exts[0] == ".moov" {
|
|
||||||
return ".mov"
|
|
||||||
}
|
|
||||||
return exts[0]
|
return exts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,14 +27,13 @@ type Pool struct {
|
|||||||
authRequiredHandler func(context.Context, *Event) error
|
authRequiredHandler func(context.Context, *Event) error
|
||||||
cancel context.CancelCauseFunc
|
cancel context.CancelCauseFunc
|
||||||
|
|
||||||
eventMiddleware func(RelayEvent)
|
EventMiddleware func(RelayEvent)
|
||||||
duplicateMiddleware func(relay string, id ID)
|
DuplicateMiddleware func(relay string, id ID)
|
||||||
queryMiddleware func(relay string, pubkey PubKey, kind Kind)
|
QueryMiddleware func(relay string, pubkey PubKey, kind Kind)
|
||||||
relayOptions RelayOptions
|
RelayOptions RelayOptions
|
||||||
|
|
||||||
// custom things not often used
|
// custom things not often used
|
||||||
penaltyBoxMu sync.Mutex
|
penaltyBox *xsync.MapOf[string, [2]float64]
|
||||||
penaltyBox map[string][2]float64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DirectedFilter combines a Filter with a specific relay URL.
|
// DirectedFilter combines a Filter with a specific relay URL.
|
||||||
@@ -50,27 +49,15 @@ func (df DirectedFilter) String() string {
|
|||||||
func (ie RelayEvent) String() string { return fmt.Sprintf("[%s] >> %s", ie.Relay.URL, ie.Event) }
|
func (ie RelayEvent) String() string { return fmt.Sprintf("[%s] >> %s", ie.Relay.URL, ie.Event) }
|
||||||
|
|
||||||
// NewPool creates a new Pool with the given context and options.
|
// NewPool creates a new Pool with the given context and options.
|
||||||
func NewPool(opts PoolOptions) *Pool {
|
func NewPool() *Pool {
|
||||||
ctx, cancel := context.WithCancelCause(context.Background())
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
pool := &Pool{
|
return &Pool{
|
||||||
Relays: xsync.NewMapOf[string, *Relay](),
|
Relays: xsync.NewMapOf[string, *Relay](),
|
||||||
|
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
|
||||||
authRequiredHandler: opts.AuthRequiredHandler,
|
|
||||||
eventMiddleware: opts.EventMiddleware,
|
|
||||||
duplicateMiddleware: opts.DuplicateMiddleware,
|
|
||||||
queryMiddleware: opts.AuthorKindQueryMiddleware,
|
|
||||||
relayOptions: opts.RelayOptions,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.PenaltyBox {
|
|
||||||
go pool.startPenaltyBox()
|
|
||||||
}
|
|
||||||
|
|
||||||
return pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PoolOptions struct {
|
type PoolOptions struct {
|
||||||
@@ -98,36 +85,50 @@ type PoolOptions struct {
|
|||||||
RelayOptions RelayOptions
|
RelayOptions RelayOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pool *Pool) startPenaltyBox() {
|
func (pool *Pool) StartPenaltyBox() {
|
||||||
pool.penaltyBox = make(map[string][2]float64)
|
pool.penaltyBox = xsync.NewMapOf[string, [2]float64]()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
sleep := 30.0
|
sleep := 30.0
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(sleep) * time.Second)
|
select {
|
||||||
|
case <-pool.Context.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(time.Duration(sleep) * time.Second):
|
||||||
|
|
||||||
pool.penaltyBoxMu.Lock()
|
nextSleep := 300.0
|
||||||
nextSleep := 300.0
|
for url, v := range pool.penaltyBox.Range {
|
||||||
for url, v := range pool.penaltyBox {
|
remainingSeconds := v[1]
|
||||||
remainingSeconds := v[1]
|
remainingSeconds -= sleep
|
||||||
remainingSeconds -= sleep
|
if remainingSeconds <= 0 {
|
||||||
if remainingSeconds <= 0 {
|
pool.penaltyBox.Store(url, [2]float64{v[0], 0})
|
||||||
pool.penaltyBox[url] = [2]float64{v[0], 0}
|
continue
|
||||||
continue
|
} else {
|
||||||
} else {
|
pool.penaltyBox.Store(url, [2]float64{v[0], remainingSeconds})
|
||||||
pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds}
|
}
|
||||||
|
|
||||||
|
if remainingSeconds < nextSleep {
|
||||||
|
nextSleep = remainingSeconds
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if remainingSeconds < nextSleep {
|
sleep = nextSleep
|
||||||
nextSleep = remainingSeconds
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sleep = nextSleep
|
|
||||||
pool.penaltyBoxMu.Unlock()
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddToPenaltyBox manually adds a relay to the penalty box for the specified duration.
|
||||||
|
// This prevents EnsureRelay from attempting to connect to the relay until the duration expires.
|
||||||
|
func (pool *Pool) AddToPenaltyBox(url string, duration time.Duration) {
|
||||||
|
if pool.penaltyBox == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nm := NormalizeURL(url)
|
||||||
|
pool.penaltyBox.Store(nm, [2]float64{0, duration.Seconds()})
|
||||||
|
pool.Relays.Store(nm, nil) // mark as explicitly disconnected for penalty box detection
|
||||||
|
}
|
||||||
|
|
||||||
// EnsureRelay ensures that a relay connection exists and is active.
|
// EnsureRelay ensures that a relay connection exists and is active.
|
||||||
// If the relay is not connected, it attempts to connect.
|
// If the relay is not connected, it attempts to connect.
|
||||||
func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
||||||
@@ -137,9 +138,7 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
|||||||
relay, ok := pool.Relays.Load(nm)
|
relay, ok := pool.Relays.Load(nm)
|
||||||
if ok && relay == nil {
|
if ok && relay == nil {
|
||||||
if pool.penaltyBox != nil {
|
if pool.penaltyBox != nil {
|
||||||
pool.penaltyBoxMu.Lock()
|
v, _ := pool.penaltyBox.Load(nm)
|
||||||
defer pool.penaltyBoxMu.Unlock()
|
|
||||||
v, _ := pool.penaltyBox[nm]
|
|
||||||
if v[1] > 0 {
|
if v[1] > 0 {
|
||||||
return nil, fmt.Errorf("in penalty box, %fs remaining", v[1])
|
return nil, fmt.Errorf("in penalty box, %fs remaining", v[1])
|
||||||
}
|
}
|
||||||
@@ -149,21 +148,27 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
|||||||
return relay, nil
|
return relay, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
relay = NewRelay(pool.Context, url, pool.relayOptions)
|
relay = NewRelay(pool.Context, url, pool.RelayOptions)
|
||||||
// try to connect
|
// try to connect
|
||||||
// we use this ctx here so when the pool dies everything dies
|
// we use this ctx here so when the pool dies everything dies
|
||||||
if err := relay.Connect(pool.Context); err != nil {
|
if err := relay.Connect(pool.Context); err != nil {
|
||||||
if pool.penaltyBox != nil {
|
if pool.penaltyBox != nil {
|
||||||
// putting relay in penalty box
|
// putting relay in penalty box
|
||||||
pool.penaltyBoxMu.Lock()
|
pool.penaltyBox.Compute(nm, func(v [2]float64, loaded bool) (newV [2]float64, delete bool) {
|
||||||
defer pool.penaltyBoxMu.Unlock()
|
return [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)}, false
|
||||||
v, _ := pool.penaltyBox[nm]
|
})
|
||||||
pool.penaltyBox[nm] = [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)}
|
pool.Relays.Store(nm, nil) // this is important for penalty box detection on EnsureRelay
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pool.Relays.Store(nm, relay)
|
pool.Relays.Store(nm, relay)
|
||||||
|
go func(r *Relay, relayURL string) {
|
||||||
|
<-r.Context().Done()
|
||||||
|
if current, ok := pool.Relays.Load(relayURL); ok && current == r {
|
||||||
|
pool.Relays.Delete(relayURL)
|
||||||
|
}
|
||||||
|
}(relay, nm)
|
||||||
return relay, nil
|
return relay, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,8 +280,8 @@ func (pool *Pool) fetchMany(
|
|||||||
if opts.CheckDuplicate == nil {
|
if opts.CheckDuplicate == nil {
|
||||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||||
if exists && pool.duplicateMiddleware != nil {
|
if exists && pool.DuplicateMiddleware != nil {
|
||||||
pool.duplicateMiddleware(relay, id)
|
pool.DuplicateMiddleware(relay, id)
|
||||||
}
|
}
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
@@ -357,7 +362,7 @@ func (pool *Pool) FetchManyReplaceable(
|
|||||||
go func(nm string) {
|
go func(nm string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
if mh := pool.queryMiddleware; mh != nil {
|
if mh := pool.QueryMiddleware; mh != nil {
|
||||||
if filter.Kinds != nil && filter.Authors != nil {
|
if filter.Kinds != nil && filter.Authors != nil {
|
||||||
for _, kind := range filter.Kinds {
|
for _, kind := range filter.Kinds {
|
||||||
for _, author := range filter.Authors {
|
for _, author := range filter.Authors {
|
||||||
@@ -405,7 +410,7 @@ func (pool *Pool) FetchManyReplaceable(
|
|||||||
}
|
}
|
||||||
|
|
||||||
ie := RelayEvent{Event: evt, Relay: relay}
|
ie := RelayEvent{Event: evt, Relay: relay}
|
||||||
if mh := pool.eventMiddleware; mh != nil {
|
if mh := pool.EventMiddleware; mh != nil {
|
||||||
mh(ie)
|
mh(ie)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -448,51 +453,55 @@ func (pool *Pool) subMany(
|
|||||||
if opts.CheckDuplicate == nil {
|
if opts.CheckDuplicate == nil {
|
||||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||||
_, exists := seenAlready.LoadOrStore(id, Now())
|
_, exists := seenAlready.LoadOrStore(id, Now())
|
||||||
if exists && pool.duplicateMiddleware != nil {
|
if exists && pool.DuplicateMiddleware != nil {
|
||||||
pool.duplicateMiddleware(relay, id)
|
pool.DuplicateMiddleware(relay, id)
|
||||||
}
|
}
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pending := xsync.NewCounter()
|
pendingWg := sync.WaitGroup{}
|
||||||
pending.Add(int64(len(urls)))
|
pendingWg.Add(len(urls))
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
pendingWg.Wait()
|
||||||
|
close(events)
|
||||||
|
cancel(fmt.Errorf("aborted: %w", context.Cause(ctx)))
|
||||||
|
if closedChan != nil {
|
||||||
|
close(closedChan)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for i, url := range urls {
|
for i, url := range urls {
|
||||||
url = NormalizeURL(url)
|
url = NormalizeURL(url)
|
||||||
urls[i] = url
|
urls[i] = url
|
||||||
if idx := slices.Index(urls, url); idx != i {
|
if idx := slices.Index(urls, url); idx != i {
|
||||||
// skip duplicate relays in the list
|
// skip duplicate relays in the list
|
||||||
eoseWg.Done()
|
eoseWg.Done()
|
||||||
pending.Dec()
|
pendingWg.Done()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
eosed := atomic.Bool{}
|
eosed := atomic.Bool{}
|
||||||
|
|
||||||
go func(nm string) {
|
go func(nm string, filter Filter) {
|
||||||
defer func() {
|
defer func() {
|
||||||
pending.Dec()
|
|
||||||
if pending.Value() == 0 {
|
|
||||||
close(events)
|
|
||||||
cancel(fmt.Errorf("aborted: %w", context.Cause(ctx)))
|
|
||||||
}
|
|
||||||
if eosed.CompareAndSwap(false, true) {
|
if eosed.CompareAndSwap(false, true) {
|
||||||
eoseWg.Done()
|
eoseWg.Done()
|
||||||
}
|
}
|
||||||
|
pendingWg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hasAuthed := false
|
hasAuthed := false
|
||||||
interval := 3 * time.Second
|
interval := 3 * time.Second
|
||||||
for {
|
for {
|
||||||
select {
|
if ctx.Err() != nil {
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var sub *Subscription
|
var sub *Subscription
|
||||||
|
|
||||||
if mh := pool.queryMiddleware; mh != nil {
|
if mh := pool.QueryMiddleware; mh != nil {
|
||||||
if filter.Kinds != nil && filter.Authors != nil {
|
if filter.Kinds != nil && filter.Authors != nil {
|
||||||
for _, kind := range filter.Kinds {
|
for _, kind := range filter.Kinds {
|
||||||
for _, author := range filter.Authors {
|
for _, author := range filter.Authors {
|
||||||
@@ -542,7 +551,7 @@ func (pool *Pool) subMany(
|
|||||||
}
|
}
|
||||||
|
|
||||||
ie := RelayEvent{Event: evt, Relay: relay}
|
ie := RelayEvent{Event: evt, Relay: relay}
|
||||||
if mh := pool.eventMiddleware; mh != nil {
|
if mh := pool.EventMiddleware; mh != nil {
|
||||||
mh(ie)
|
mh(ie)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -567,10 +576,13 @@ func (pool *Pool) subMany(
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
hasAuthed = true // so we don't keep doing AUTH again and again
|
hasAuthed = true // so we don't keep doing AUTH again and again
|
||||||
if closedChan != nil {
|
if closedChan != nil {
|
||||||
closedChan <- RelayClosed{
|
select {
|
||||||
|
case closedChan <- RelayClosed{
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
Relay: relay,
|
Relay: relay,
|
||||||
HandledAuth: true,
|
HandledAuth: true,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
goto subscribe
|
goto subscribe
|
||||||
@@ -578,9 +590,12 @@ func (pool *Pool) subMany(
|
|||||||
}
|
}
|
||||||
debugLogf("CLOSED from %s: '%s'\n", nm, reason)
|
debugLogf("CLOSED from %s: '%s'\n", nm, reason)
|
||||||
if closedChan != nil {
|
if closedChan != nil {
|
||||||
closedChan <- RelayClosed{
|
select {
|
||||||
|
case closedChan <- RelayClosed{
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
Relay: relay,
|
Relay: relay,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -597,7 +612,7 @@ func (pool *Pool) subMany(
|
|||||||
time.Sleep(interval)
|
time.Sleep(interval)
|
||||||
interval = min(10*time.Minute, interval*17/10) // the next time we try we will wait longer
|
interval = min(10*time.Minute, interval*17/10) // the next time we try we will wait longer
|
||||||
}
|
}
|
||||||
}(url)
|
}(url, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
return events
|
return events
|
||||||
@@ -621,13 +636,16 @@ func (pool *Pool) subManyEose(
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
cancel(errors.New("all subscriptions ended"))
|
cancel(errors.New("all subscriptions ended"))
|
||||||
close(events)
|
close(events)
|
||||||
|
if closedChan != nil {
|
||||||
|
close(closedChan)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
go func(nm string) {
|
go func(nm string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
if mh := pool.queryMiddleware; mh != nil {
|
if mh := pool.QueryMiddleware; mh != nil {
|
||||||
if filter.Kinds != nil && filter.Authors != nil {
|
if filter.Kinds != nil && filter.Authors != nil {
|
||||||
for _, kind := range filter.Kinds {
|
for _, kind := range filter.Kinds {
|
||||||
for _, author := range filter.Authors {
|
for _, author := range filter.Authors {
|
||||||
@@ -665,10 +683,13 @@ func (pool *Pool) subManyEose(
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
hasAuthed = true // so we don't keep doing AUTH again and again
|
hasAuthed = true // so we don't keep doing AUTH again and again
|
||||||
if closedChan != nil {
|
if closedChan != nil {
|
||||||
closedChan <- RelayClosed{
|
select {
|
||||||
|
case closedChan <- RelayClosed{
|
||||||
Relay: relay,
|
Relay: relay,
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
HandledAuth: true,
|
HandledAuth: true,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
goto subscribe
|
goto subscribe
|
||||||
@@ -676,9 +697,12 @@ func (pool *Pool) subManyEose(
|
|||||||
}
|
}
|
||||||
debugLogf("[pool] CLOSED from %s: '%s'\n", nm, reason)
|
debugLogf("[pool] CLOSED from %s: '%s'\n", nm, reason)
|
||||||
if closedChan != nil {
|
if closedChan != nil {
|
||||||
closedChan <- RelayClosed{
|
select {
|
||||||
|
case closedChan <- RelayClosed{
|
||||||
Relay: relay,
|
Relay: relay,
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -688,7 +712,7 @@ func (pool *Pool) subManyEose(
|
|||||||
}
|
}
|
||||||
|
|
||||||
ie := RelayEvent{Event: evt, Relay: relay}
|
ie := RelayEvent{Event: evt, Relay: relay}
|
||||||
if mh := pool.eventMiddleware; mh != nil {
|
if mh := pool.EventMiddleware; mh != nil {
|
||||||
mh(ie)
|
mh(ie)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -783,21 +807,40 @@ func (pool *Pool) batchedQueryMany(
|
|||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(len(dfs))
|
wg.Add(len(dfs))
|
||||||
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
||||||
|
forwardWg := sync.WaitGroup{}
|
||||||
|
|
||||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||||
if exists && pool.duplicateMiddleware != nil {
|
if exists && pool.DuplicateMiddleware != nil {
|
||||||
pool.duplicateMiddleware(relay, id)
|
pool.DuplicateMiddleware(relay, id)
|
||||||
}
|
}
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, df := range dfs {
|
for _, df := range dfs {
|
||||||
go func(df DirectedFilter) {
|
go func(df DirectedFilter) {
|
||||||
|
var innerClosed chan RelayClosed
|
||||||
|
if closedChan != nil {
|
||||||
|
innerClosed = make(chan RelayClosed)
|
||||||
|
forwardWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer forwardWg.Done()
|
||||||
|
for rc := range innerClosed {
|
||||||
|
select {
|
||||||
|
case closedChan <- rc:
|
||||||
|
case <-ctx.Done():
|
||||||
|
for range innerClosed {
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
for ie := range pool.subManyEose(ctx,
|
for ie := range pool.subManyEose(ctx,
|
||||||
[]string{df.Relay},
|
[]string{df.Relay},
|
||||||
df.Filter,
|
df.Filter,
|
||||||
closedChan,
|
innerClosed,
|
||||||
opts,
|
opts,
|
||||||
) {
|
) {
|
||||||
select {
|
select {
|
||||||
@@ -814,6 +857,10 @@ func (pool *Pool) batchedQueryMany(
|
|||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(res)
|
close(res)
|
||||||
|
if closedChan != nil {
|
||||||
|
forwardWg.Wait()
|
||||||
|
close(closedChan)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@@ -849,22 +896,41 @@ func (pool *Pool) batchedSubscribeMany(
|
|||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(len(dfs))
|
wg.Add(len(dfs))
|
||||||
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
||||||
|
forwardWg := sync.WaitGroup{}
|
||||||
|
|
||||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||||
if exists && pool.duplicateMiddleware != nil {
|
if exists && pool.DuplicateMiddleware != nil {
|
||||||
pool.duplicateMiddleware(relay, id)
|
pool.DuplicateMiddleware(relay, id)
|
||||||
}
|
}
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, df := range dfs {
|
for _, df := range dfs {
|
||||||
go func(df DirectedFilter) {
|
go func(df DirectedFilter) {
|
||||||
|
var innerClosed chan RelayClosed
|
||||||
|
if closedChan != nil {
|
||||||
|
innerClosed = make(chan RelayClosed)
|
||||||
|
forwardWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer forwardWg.Done()
|
||||||
|
for rc := range innerClosed {
|
||||||
|
select {
|
||||||
|
case closedChan <- rc:
|
||||||
|
case <-ctx.Done():
|
||||||
|
for range innerClosed {
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
for ie := range pool.subMany(ctx,
|
for ie := range pool.subMany(ctx,
|
||||||
[]string{df.Relay},
|
[]string{df.Relay},
|
||||||
df.Filter,
|
df.Filter,
|
||||||
nil,
|
nil,
|
||||||
closedChan,
|
innerClosed,
|
||||||
opts,
|
opts,
|
||||||
) {
|
) {
|
||||||
select {
|
select {
|
||||||
@@ -881,6 +947,10 @@ func (pool *Pool) batchedSubscribeMany(
|
|||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(res)
|
close(res)
|
||||||
|
if closedChan != nil {
|
||||||
|
forwardWg.Wait()
|
||||||
|
close(closedChan)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|||||||
@@ -1,15 +1,19 @@
|
|||||||
package nostr
|
package nostr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"iter"
|
"iter"
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,18 +24,37 @@ import (
|
|||||||
|
|
||||||
var subscriptionIDCounter atomic.Int64
|
var subscriptionIDCounter atomic.Int64
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrDisconnected = errors.New("<disconnected>")
|
||||||
|
ErrPingFailed = errors.New("<ping failed>")
|
||||||
|
)
|
||||||
|
|
||||||
|
type writeRequest struct {
|
||||||
|
msg []byte
|
||||||
|
answer chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type closeCause struct {
|
||||||
|
code ws.StatusCode
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c closeCause) Error() string {
|
||||||
|
if c.reason == "" {
|
||||||
|
return "relay closed"
|
||||||
|
}
|
||||||
|
return c.reason
|
||||||
|
}
|
||||||
|
|
||||||
// Relay represents a connection to a Nostr relay.
|
// Relay represents a connection to a Nostr relay.
|
||||||
type Relay struct {
|
type Relay struct {
|
||||||
closeMutex sync.Mutex
|
|
||||||
|
|
||||||
URL string
|
URL string
|
||||||
requestHeader http.Header // e.g. for origin header
|
requestHeader http.Header // e.g. for origin header
|
||||||
|
|
||||||
// websocket connection
|
// websocket connection
|
||||||
conn *ws.Conn
|
conn *ws.Conn
|
||||||
writeQueue chan writeRequest
|
writeQueue chan writeRequest
|
||||||
closed *atomic.Bool
|
closed *atomic.Bool
|
||||||
closedNotify chan struct{}
|
|
||||||
|
|
||||||
Subscriptions *xsync.MapOf[int64, *Subscription]
|
Subscriptions *xsync.MapOf[int64, *Subscription]
|
||||||
|
|
||||||
@@ -39,7 +62,10 @@ type Relay struct {
|
|||||||
connectionContext context.Context // will be canceled when the connection closes
|
connectionContext context.Context // will be canceled when the connection closes
|
||||||
connectionContextCancel context.CancelCauseFunc
|
connectionContextCancel context.CancelCauseFunc
|
||||||
|
|
||||||
challenge string // NIP-42 challenge, we only keep the last
|
challenge string // NIP-42 challenge, we only keep the last
|
||||||
|
performAuth sync.Once
|
||||||
|
authed bool
|
||||||
|
|
||||||
authHandler func(context.Context, *Relay, *Event) error
|
authHandler func(context.Context, *Relay, *Event) error
|
||||||
noticeHandler func(*Relay, string) // NIP-01 NOTICEs
|
noticeHandler func(*Relay, string) // NIP-01 NOTICEs
|
||||||
customHandler func(string) // nonstandard unparseable messages
|
customHandler func(string) // nonstandard unparseable messages
|
||||||
@@ -66,8 +92,32 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
|||||||
customHandler: opts.CustomHandler,
|
customHandler: opts.CustomHandler,
|
||||||
noticeHandler: opts.NoticeHandler,
|
noticeHandler: opts.NoticeHandler,
|
||||||
authHandler: opts.AuthHandler,
|
authHandler: opts.AuthHandler,
|
||||||
|
closed: &atomic.Bool{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if wasClosed := r.closed.Swap(true); wasClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.conn != nil {
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
code := ws.StatusNormalClosure
|
||||||
|
reason := ""
|
||||||
|
var cc closeCause
|
||||||
|
if errors.As(cause, &cc) {
|
||||||
|
code = cc.code
|
||||||
|
reason = cc.reason
|
||||||
|
} else if cause != nil {
|
||||||
|
reason = cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = r.conn.Close(code, reason)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +159,18 @@ func (r *Relay) String() string {
|
|||||||
func (r *Relay) Context() context.Context { return r.connectionContext }
|
func (r *Relay) Context() context.Context { return r.connectionContext }
|
||||||
|
|
||||||
// IsConnected returns true if the connection to this relay seems to be active.
|
// IsConnected returns true if the connection to this relay seems to be active.
|
||||||
func (r *Relay) IsConnected() bool { return !r.closed.Load() }
|
func (r *Relay) IsConnected() bool {
|
||||||
|
if r.closed.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.conn == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.connectionContext == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return r.connectionContext.Err() == nil
|
||||||
|
}
|
||||||
|
|
||||||
// Connect tries to establish a websocket connection to r.URL.
|
// Connect tries to establish a websocket connection to r.URL.
|
||||||
// If the context expires before the connection is complete, an error is returned.
|
// If the context expires before the connection is complete, an error is returned.
|
||||||
@@ -136,6 +197,9 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
|||||||
if r.connectionContext == nil || r.Subscriptions == nil {
|
if r.connectionContext == nil || r.Subscriptions == nil {
|
||||||
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
|
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
|
||||||
}
|
}
|
||||||
|
if r.connectionContext.Err() != nil {
|
||||||
|
return fmt.Errorf("relay context canceled")
|
||||||
|
}
|
||||||
|
|
||||||
if r.URL == "" {
|
if r.URL == "" {
|
||||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||||
@@ -148,6 +212,128 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
||||||
|
debugLogf("{%s} connecting!\n", r.URL)
|
||||||
|
|
||||||
|
dialCtx := ctx
|
||||||
|
if _, ok := dialCtx.Deadline(); !ok {
|
||||||
|
// if no timeout is set, force it to 7 seconds
|
||||||
|
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||||
|
}
|
||||||
|
|
||||||
|
dialOpts := &ws.DialOptions{
|
||||||
|
HTTPHeader: http.Header{
|
||||||
|
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
||||||
|
},
|
||||||
|
CompressionMode: ws.CompressionContextTakeover,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}
|
||||||
|
for k, v := range r.requestHeader {
|
||||||
|
dialOpts.HTTPHeader[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.SetReadLimit(2 << 24) // 33MB
|
||||||
|
|
||||||
|
// ping every 19 seconds
|
||||||
|
ticker := time.NewTicker(19 * time.Second)
|
||||||
|
|
||||||
|
// main websocket loop
|
||||||
|
readQueue := make(chan string, 64 /* add some buffer to account for processing/IO mismatches */)
|
||||||
|
|
||||||
|
r.conn = c
|
||||||
|
r.writeQueue = make(chan writeRequest, 64 /* idem */)
|
||||||
|
r.closed = &atomic.Bool{}
|
||||||
|
|
||||||
|
connCtx := r.connectionContext
|
||||||
|
go func() {
|
||||||
|
pingAttempt := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-connCtx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
debugLogf("{%s} pinging\n", r.URL)
|
||||||
|
pingCtx, cancel := context.WithTimeoutCause(connCtx, time.Millisecond*800, errors.New("ping took too long"))
|
||||||
|
err := c.Ping(pingCtx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
pingAttempt++
|
||||||
|
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
||||||
|
|
||||||
|
if pingAttempt >= 3 {
|
||||||
|
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
||||||
|
_ = r.close(ErrPingFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ping was OK
|
||||||
|
debugLogf("{%s} ping OK", r.URL)
|
||||||
|
pingAttempt = 0
|
||||||
|
case wr := <-r.writeQueue:
|
||||||
|
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||||
|
writeCtx, cancel := context.WithTimeoutCause(connCtx, time.Second*10, errors.New("write took too long"))
|
||||||
|
err := c.Write(writeCtx, ws.MessageText, wr.msg)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||||
|
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "write failed"})
|
||||||
|
if wr.answer != nil {
|
||||||
|
wr.answer <- err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if wr.answer != nil {
|
||||||
|
close(wr.answer)
|
||||||
|
}
|
||||||
|
case msg := <-readQueue:
|
||||||
|
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||||
|
r.handleMessage(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// read loop -- loops back to the main loop
|
||||||
|
go func() {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
for {
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
_, reader, err := c.Reader(connCtx)
|
||||||
|
if err != nil {
|
||||||
|
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||||
|
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to get reader"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(buf, reader); err != nil {
|
||||||
|
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||||
|
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to read"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := string(buf.Bytes())
|
||||||
|
select {
|
||||||
|
case readQueue <- msg:
|
||||||
|
case <-connCtx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Relay) handleMessage(message string) {
|
func (r *Relay) handleMessage(message string) {
|
||||||
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
||||||
// as we skip handling duplicate events
|
// as we skip handling duplicate events
|
||||||
@@ -178,7 +364,6 @@ func (r *Relay) handleMessage(message string) {
|
|||||||
|
|
||||||
switch env := envelope.(type) {
|
switch env := envelope.(type) {
|
||||||
case *NoticeEnvelope:
|
case *NoticeEnvelope:
|
||||||
// see WithNoticeHandler
|
|
||||||
if r.noticeHandler != nil {
|
if r.noticeHandler != nil {
|
||||||
r.noticeHandler(r, string(*env))
|
r.noticeHandler(r, string(*env))
|
||||||
} else {
|
} else {
|
||||||
@@ -188,7 +373,10 @@ func (r *Relay) handleMessage(message string) {
|
|||||||
if env.Challenge == nil {
|
if env.Challenge == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.performAuth = sync.Once{} // this ensures we can try to auth again
|
||||||
r.challenge = *env.Challenge
|
r.challenge = *env.Challenge
|
||||||
|
|
||||||
if r.authHandler != nil {
|
if r.authHandler != nil {
|
||||||
go func() {
|
go func() {
|
||||||
r.Auth(r.Context(), func(ctx context.Context, evt *Event) error {
|
r.Auth(r.Context(), func(ctx context.Context, evt *Event) error {
|
||||||
@@ -244,14 +432,6 @@ func (r *Relay) handleMessage(message string) {
|
|||||||
|
|
||||||
// Write queues an arbitrary message to be sent to the relay.
|
// Write queues an arbitrary message to be sent to the relay.
|
||||||
func (r *Relay) Write(msg []byte) {
|
func (r *Relay) Write(msg []byte) {
|
||||||
r.closeMutex.Lock()
|
|
||||||
defer r.closeMutex.Unlock()
|
|
||||||
select {
|
|
||||||
case <-r.closedNotify:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.connectionContext.Done():
|
case <-r.connectionContext.Done():
|
||||||
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
||||||
@@ -260,20 +440,24 @@ func (r *Relay) Write(msg []byte) {
|
|||||||
|
|
||||||
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
||||||
func (r *Relay) WriteWithError(msg []byte) error {
|
func (r *Relay) WriteWithError(msg []byte) error {
|
||||||
ch := make(chan error)
|
ch := make(chan error, 1)
|
||||||
r.closeMutex.Lock()
|
|
||||||
defer r.closeMutex.Unlock()
|
if r.writeQueue == nil {
|
||||||
select {
|
return nil
|
||||||
case <-r.closedNotify:
|
|
||||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.connectionContext.Done():
|
case <-r.connectionContext.Done():
|
||||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||||
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||||
}
|
}
|
||||||
return <-ch
|
|
||||||
|
select {
|
||||||
|
case err := <-ch:
|
||||||
|
return err
|
||||||
|
case <-r.connectionContext.Done():
|
||||||
|
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
|
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
|
||||||
@@ -286,20 +470,38 @@ func (r *Relay) Publish(ctx context.Context, event Event) error {
|
|||||||
// You don't have to build the AUTH event yourself, this function takes a function to which the
|
// You don't have to build the AUTH event yourself, this function takes a function to which the
|
||||||
// event that must be signed will be passed, so it's only necessary to sign that.
|
// event that must be signed will be passed, so it's only necessary to sign that.
|
||||||
func (r *Relay) Auth(ctx context.Context, sign func(context.Context, *Event) error) error {
|
func (r *Relay) Auth(ctx context.Context, sign func(context.Context, *Event) error) error {
|
||||||
authEvent := Event{
|
if r.authed {
|
||||||
CreatedAt: Now(),
|
return nil
|
||||||
Kind: KindClientAuthentication,
|
|
||||||
Tags: Tags{
|
|
||||||
Tag{"relay", r.URL},
|
|
||||||
Tag{"challenge", r.challenge},
|
|
||||||
},
|
|
||||||
Content: "",
|
|
||||||
}
|
|
||||||
if err := sign(ctx, &authEvent); err != nil {
|
|
||||||
return fmt.Errorf("error signing auth event: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.publish(ctx, authEvent.ID, &AuthEnvelope{Event: authEvent})
|
if r.challenge == "" {
|
||||||
|
return fmt.Errorf("no challenge, can't AUTH")
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
r.performAuth.Do(func() {
|
||||||
|
authEvent := Event{
|
||||||
|
CreatedAt: Now(),
|
||||||
|
Kind: KindClientAuthentication,
|
||||||
|
Tags: Tags{
|
||||||
|
Tag{"relay", r.URL},
|
||||||
|
Tag{"challenge", r.challenge},
|
||||||
|
},
|
||||||
|
Content: "",
|
||||||
|
}
|
||||||
|
if err := sign(ctx, &authEvent); err != nil {
|
||||||
|
err = fmt.Errorf("error signing auth event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.publish(ctx, authEvent.ID, &AuthEnvelope{Event: authEvent})
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
r.authed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// publish can be used both for EVENT and for AUTH
|
// publish can be used both for EVENT and for AUTH
|
||||||
@@ -381,22 +583,20 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
|||||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
||||||
// Failure to do that will result in a huge number of halted goroutines being created.
|
// Failure to do that will result in a huge number of halted goroutines being created.
|
||||||
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
||||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
if !r.IsConnected() {
|
||||||
|
return nil, ErrDisconnected
|
||||||
if r.conn == nil {
|
|
||||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||||
|
|
||||||
if err := sub.Fire(); err != nil {
|
if err := sub.Fire(); err != nil {
|
||||||
|
sub.cancel(ErrFireFailed)
|
||||||
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
<-ctx.Done()
|
||||||
case <-r.closedNotify:
|
sub.cancel(nil)
|
||||||
sub.unsub(ErrDisconnected)
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return sub, nil
|
return sub, nil
|
||||||
@@ -420,6 +620,7 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
|||||||
ClosedReason: make(chan string, 1),
|
ClosedReason: make(chan string, 1),
|
||||||
Filter: filter,
|
Filter: filter,
|
||||||
match: filter.Matches,
|
match: filter.Matches,
|
||||||
|
eoseTimedOut: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
sub.checkDuplicate = opts.CheckDuplicate
|
sub.checkDuplicate = opts.CheckDuplicate
|
||||||
@@ -444,12 +645,45 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(opts.MaxWaitForEOSE)
|
time.Sleep(opts.MaxWaitForEOSE)
|
||||||
|
close(sub.eoseTimedOut)
|
||||||
sub.dispatchEose()
|
sub.dispatchEose()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if the relay connection dies, cancel this subscription
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-sub.Context.Done():
|
||||||
|
return
|
||||||
|
case <-r.connectionContext.Done():
|
||||||
|
sub.cancel(context.Cause(r.connectionContext))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// start handling events, eose, unsub etc:
|
// start handling events, eose, unsub etc:
|
||||||
go sub.start()
|
go func() {
|
||||||
|
<-sub.Context.Done()
|
||||||
|
|
||||||
|
// mark subscription as closed and send a CLOSE to the relay (naive sync.Once implementation)
|
||||||
|
if sub.live.CompareAndSwap(true, false) {
|
||||||
|
closeMsg := CloseEnvelope(sub.id)
|
||||||
|
closeb, _ := (&closeMsg).MarshalJSON()
|
||||||
|
if err := sub.Relay.WriteWithError(closeb); err != nil {
|
||||||
|
_ = sub.Relay.close(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove subscription from our map
|
||||||
|
sub.Relay.Subscriptions.Delete(sub.counter)
|
||||||
|
|
||||||
|
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
|
||||||
|
sub.mu.Lock()
|
||||||
|
close(sub.Events)
|
||||||
|
if sub.countResult != nil {
|
||||||
|
close(sub.countResult)
|
||||||
|
}
|
||||||
|
sub.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
return sub
|
return sub
|
||||||
}
|
}
|
||||||
@@ -484,29 +718,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
|
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
|
||||||
func (r *Relay) Count(
|
// If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler.
|
||||||
ctx context.Context,
|
func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) {
|
||||||
filter Filter,
|
|
||||||
opts SubscriptionOptions,
|
|
||||||
) (uint32, []byte, error) {
|
|
||||||
v, err := r.countInternal(ctx, filter, opts)
|
v, err := r.countInternal(ctx, filter, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v.Count == nil {
|
||||||
|
return 0, nil, errors.New("count subscription ended without result")
|
||||||
|
}
|
||||||
|
|
||||||
return *v.Count, v.HyperLogLog, nil
|
return *v.Count, v.HyperLogLog, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
|
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
|
||||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
if !r.IsConnected() {
|
||||||
sub.countResult = make(chan CountEnvelope)
|
return CountEnvelope{}, ErrDisconnected
|
||||||
|
|
||||||
if err := sub.Fire(); err != nil {
|
|
||||||
return CountEnvelope{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer sub.unsub(errors.New("countInternal() ended"))
|
|
||||||
|
|
||||||
if _, ok := ctx.Deadline(); !ok {
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
// if no timeout is set, force it to 7 seconds
|
// if no timeout is set, force it to 7 seconds
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
@@ -514,13 +744,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hasAuthed := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||||
case count := <-sub.countResult:
|
sub.countResult = make(chan CountEnvelope, 1)
|
||||||
return count, nil
|
|
||||||
case <-ctx.Done():
|
if err := sub.Fire(); err != nil {
|
||||||
return CountEnvelope{}, ctx.Err()
|
sub.cancel(ErrFireFailed)
|
||||||
|
return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
sub.cancel(nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case count, ok := <-sub.countResult:
|
||||||
|
sub.cancel(errors.New("countInternal() ended"))
|
||||||
|
if !ok || count.Count == nil {
|
||||||
|
return CountEnvelope{}, errors.New("count subscription ended without result")
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
case reason := <-sub.ClosedReason:
|
||||||
|
sub.cancel(errors.New("countInternal() ended"))
|
||||||
|
if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed {
|
||||||
|
authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error {
|
||||||
|
return r.authHandler(authCtx, r, evt)
|
||||||
|
})
|
||||||
|
if authErr == nil {
|
||||||
|
hasAuthed = true
|
||||||
|
goto resubscribe
|
||||||
|
}
|
||||||
|
return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr)
|
||||||
|
}
|
||||||
|
return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason)
|
||||||
|
case <-sub.Context.Done():
|
||||||
|
sub.cancel(errors.New("countInternal() ended"))
|
||||||
|
return CountEnvelope{}, context.Cause(sub.Context)
|
||||||
|
case <-ctx.Done():
|
||||||
|
sub.cancel(errors.New("countInternal() ended"))
|
||||||
|
return CountEnvelope{}, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resubscribe:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,19 +801,7 @@ func (r *Relay) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) close(reason error) error {
|
func (r *Relay) close(reason error) error {
|
||||||
r.closeMutex.Lock()
|
|
||||||
defer r.closeMutex.Unlock()
|
|
||||||
|
|
||||||
if r.connectionContextCancel == nil {
|
|
||||||
return fmt.Errorf("relay already closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.conn == nil {
|
|
||||||
return fmt.Errorf("relay not connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.connectionContextCancel(reason)
|
r.connectionContextCancel(reason)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -18,7 +18,7 @@ import (
|
|||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultSchemaURL = "https://raw.githubusercontent.com/nostr-protocol/registry-of-kinds/952c36fcd7129aa85be22d00c2b381ae47ee9c18/schema.yaml"
|
const DefaultSchemaURL = "https://raw.githubusercontent.com/nostr-protocol/registry-of-kinds/ffa18bf6fb5496d755b465b062e18c676df1a5d4/schema.yaml"
|
||||||
|
|
||||||
// this is used by hex.Decode in the "hex" validator -- we don't care about data races
|
// this is used by hex.Decode in the "hex" validator -- we don't care about data races
|
||||||
var hexdummydecoder = make([]byte, 128)
|
var hexdummydecoder = make([]byte, 128)
|
||||||
|
|||||||
+13
-13
@@ -124,7 +124,7 @@ func TestValidateNext_ID(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "id", Required: true}
|
next := &ContentSpec{Type: "id", Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -163,7 +163,7 @@ func TestValidateNext_PubKey(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "pubkey", Required: true}
|
next := &ContentSpec{Type: "pubkey", Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -212,7 +212,7 @@ func TestValidateNext_Relay(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "relay", Required: true}
|
next := &ContentSpec{Type: "relay", Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -261,7 +261,7 @@ func TestValidateNext_Kind(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "kind", Required: true}
|
next := &ContentSpec{Type: "kind", Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -298,7 +298,7 @@ func TestValidateNext_Constrained(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "constrained", Required: true, Either: tt.allowed}
|
next := &ContentSpec{Type: "constrained", Required: true, Either: tt.allowed}
|
||||||
_, err := v.validateNext(tt.tag, 3, next)
|
_, err := v.validateNext(tt.tag, 3, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -342,7 +342,7 @@ func TestValidateNext_GitCommit(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "hex", Min: 40, Max: 40, Required: true}
|
next := &ContentSpec{Type: "hex", Min: 40, Max: 40, Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -376,7 +376,7 @@ func TestValidateNext_Addr(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
next := &nextSpec{Type: "addr", Required: true}
|
next := &ContentSpec{Type: "addr", Required: true}
|
||||||
_, err := v.validateNext(tt.tag, 1, next)
|
_, err := v.validateNext(tt.tag, 1, next)
|
||||||
if tt.valid {
|
if tt.valid {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -393,7 +393,7 @@ func TestValidateNext_Free(t *testing.T) {
|
|||||||
|
|
||||||
// free type should accept anything
|
// free type should accept anything
|
||||||
tag := nostr.Tag{"test", "any value here", "even", "multiple", "values"}
|
tag := nostr.Tag{"test", "any value here", "even", "multiple", "values"}
|
||||||
next := &nextSpec{Type: "free", Required: true}
|
next := &ContentSpec{Type: "free", Required: true}
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
@@ -403,7 +403,7 @@ func TestValidateNext_UnknownType(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tag := nostr.Tag{"test", "value"}
|
tag := nostr.Tag{"test", "value"}
|
||||||
next := &nextSpec{Type: "unknown-type", Required: true}
|
next := &ContentSpec{Type: "unknown-type", Required: true}
|
||||||
|
|
||||||
// should not fail when FailOnUnknownType is false (default)
|
// should not fail when FailOnUnknownType is false (default)
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
@@ -422,13 +422,13 @@ func TestValidateNext_RequiredField(t *testing.T) {
|
|||||||
|
|
||||||
// test missing required field
|
// test missing required field
|
||||||
tag := nostr.Tag{"test"} // only name, missing required value
|
tag := nostr.Tag{"test"} // only name, missing required value
|
||||||
next := &nextSpec{Type: "free", Required: true}
|
next := &ContentSpec{Type: "free", Required: true}
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "missing index 1")
|
require.Contains(t, err.Error(), "missing index 1")
|
||||||
|
|
||||||
// test optional field
|
// test optional field
|
||||||
next = &nextSpec{Type: "free", Required: false}
|
next = &ContentSpec{Type: "free", Required: false}
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
@@ -439,7 +439,7 @@ func TestValidateNext_Variadic(t *testing.T) {
|
|||||||
|
|
||||||
// test variadic field with multiple values
|
// test variadic field with multiple values
|
||||||
tag := nostr.Tag{"test", "value1", "value2", "value3"}
|
tag := nostr.Tag{"test", "value1", "value2", "value3"}
|
||||||
next := &nextSpec{Type: "free", Variadic: true}
|
next := &ContentSpec{Type: "free", Variadic: true}
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -450,7 +450,7 @@ func TestValidateNext_Variadic(t *testing.T) {
|
|||||||
|
|
||||||
// test variadic field with no values (should fail if required)
|
// test variadic field with no values (should fail if required)
|
||||||
tag = nostr.Tag{"test"}
|
tag = nostr.Tag{"test"}
|
||||||
next = &nextSpec{Type: "free", Variadic: true, Required: true}
|
next = &ContentSpec{Type: "free", Variadic: true, Required: true}
|
||||||
_, err = v.validateNext(tag, 1, next)
|
_, err = v.validateNext(tag, 1, next)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package sdk
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
@@ -55,14 +56,15 @@ func (sys *System) batchLoadAddressableEvents(
|
|||||||
cm := sync.Mutex{}
|
cm := sync.Mutex{}
|
||||||
|
|
||||||
aggregatedContext, aggregatedCancel := context.WithCancel(context.Background())
|
aggregatedContext, aggregatedCancel := context.WithCancel(context.Background())
|
||||||
waiting := len(pubkeys)
|
waiting := atomic.Int32{}
|
||||||
|
waiting.Add(int32(len(pubkeys)))
|
||||||
|
|
||||||
for i, pubkey := range pubkeys {
|
for i, pubkey := range pubkeys {
|
||||||
ctx, cancel := context.WithCancel(ctxs[i])
|
ctx, cancel := context.WithCancel(ctxs[i])
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// build batched queries for the external relays
|
// build batched queries for the external relays
|
||||||
go func(i int, pubkey nostr.PubKey) {
|
go func(i int, pubkey nostr.PubKey, ctx context.Context) {
|
||||||
// gather relays we'll use for this pubkey
|
// gather relays we'll use for this pubkey
|
||||||
relays := sys.determineRelaysToQuery(ctx, pubkey, kind)
|
relays := sys.determineRelaysToQuery(ctx, pubkey, kind)
|
||||||
|
|
||||||
@@ -92,11 +94,10 @@ func (sys *System) batchLoadAddressableEvents(
|
|||||||
wg.Done()
|
wg.Done()
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
waiting--
|
if waiting.Add(-1) == 0 {
|
||||||
if waiting == 0 {
|
|
||||||
aggregatedCancel()
|
aggregatedCancel()
|
||||||
}
|
}
|
||||||
}(i, pubkey)
|
}(i, pubkey, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for relay batches to be prepared
|
// wait for relay batches to be prepared
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user