Compare commits
83 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4daeb8737c | |||
| 7ab69cbc60 | |||
| 029f4eb0d8 | |||
| cf734a3ac7 | |||
| d92a0cde16 | |||
| 5944a3ead6 | |||
| 3e35681cb9 | |||
| 8515153df2 | |||
| 98fa53464e | |||
| 29cdd48fcb | |||
| 181de14642 | |||
| 1794f0690f | |||
| 12af4717d4 | |||
| b989b66bb7 | |||
| 4261bc88f8 | |||
| a8205a3790 | |||
| 0152341144 | |||
| 9bf9816c15 | |||
| 82f2fbdb99 | |||
| d5b54a1c91 | |||
| 637412fd38 | |||
| 9b881801d8 | |||
| 371cecdb84 | |||
| 2735abe060 | |||
| b9a3e78752 | |||
| ff03090610 | |||
| 72a5be58d7 | |||
| 2c30300756 | |||
| d1fdc262f2 | |||
| 117a304f68 | |||
| ac2d4579f1 | |||
| 56610a32e6 | |||
| d4940c7858 | |||
| 172e7890b9 | |||
| 3acfbbca0a | |||
| b5974cfa45 | |||
| c74ac74a0e | |||
| ec6f3f8a41 | |||
| d43fbbf02d | |||
| 6a686c31af | |||
| a6fdcd8b30 | |||
| e675f04bd2 | |||
| 0630bbe4e9 | |||
| 55c5194bdf | |||
| f3f5c3982d | |||
| 1520264394 | |||
| 2cec1c9434 | |||
| 6cbe984e16 | |||
| 5a0b18e65a | |||
| bb4093d834 | |||
| 3bd059d1f9 | |||
| 681bd55e55 | |||
| 4e490879b5 | |||
| 4348c64b14 | |||
| 2c0d9712e3 | |||
| 4719c0bc9f | |||
| 163e59e1f1 | |||
| 21ce0046c0 | |||
| 1d14e6bebe | |||
| 23d525f067 | |||
| 4dab261bdf | |||
| 44c429d6b1 | |||
| 4b5c51ffc0 | |||
| 5de9501556 | |||
| 8ba05114cd | |||
| 1df85217d9 | |||
| 195cb944e2 | |||
| c31b92707b | |||
| 00ffe16cb7 | |||
| 4d1b6c1df0 | |||
| 62d15178ec | |||
| 32dd39da81 | |||
| 7aa127a8c3 | |||
| 55cc52876a | |||
| 137c09369a | |||
| d445ba9919 | |||
| d30c1bff46 | |||
| 65ef1c50a7 | |||
| 7a4b71b39b | |||
| 3f52d10421 | |||
| a98ac0d050 | |||
| 28bef1c990 | |||
| beb8a72491 |
-158
@@ -1,158 +0,0 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
var ErrDisconnected = errors.New("<disconnected>")
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
||||
debugLogf("{%s} connecting!\n", r.URL)
|
||||
|
||||
dialCtx := ctx
|
||||
if _, ok := dialCtx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
}
|
||||
|
||||
dialOpts := &ws.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
||||
},
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
for k, v := range r.requestHeader {
|
||||
dialOpts.HTTPHeader[k] = v
|
||||
}
|
||||
|
||||
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.SetReadLimit(2 << 24) // 33MB
|
||||
|
||||
// this will tell if the connection is closed
|
||||
|
||||
// ping every 29 seconds
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
|
||||
// main websocket loop
|
||||
readQueue := make(chan string)
|
||||
|
||||
r.conn = c
|
||||
r.writeQueue = make(chan writeRequest)
|
||||
r.closed = &atomic.Bool{}
|
||||
r.closedNotify = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
pingAttempt := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.closeConnection(ws.StatusNormalClosure, "")
|
||||
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
|
||||
return
|
||||
case <-r.closedNotify:
|
||||
return
|
||||
case <-ticker.C:
|
||||
debugLogf("{%s} pinging\n", r.URL)
|
||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
|
||||
err := c.Ping(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
pingAttempt++
|
||||
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
||||
|
||||
if pingAttempt >= 3 {
|
||||
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
||||
err = r.Close() // this should trigger a context cancelation
|
||||
if err != nil {
|
||||
debugLogf("{%s} failed to close relay: %v", r.URL, err)
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// ping was OK
|
||||
debugLogf("{%s} ping OK", r.URL)
|
||||
pingAttempt = 0
|
||||
case wr := <-r.writeQueue:
|
||||
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
|
||||
err := c.Write(ctx, ws.MessageText, wr.msg)
|
||||
cancel()
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
|
||||
if wr.answer != nil {
|
||||
wr.answer <- err
|
||||
}
|
||||
return
|
||||
}
|
||||
if wr.answer != nil {
|
||||
close(wr.answer)
|
||||
}
|
||||
case msg := <-readQueue:
|
||||
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||
r.handleMessage(msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// read loop -- loops back to the main loop
|
||||
go func() {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
for {
|
||||
buf.Reset()
|
||||
|
||||
_, reader, err := c.Reader(ctx)
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
|
||||
return
|
||||
}
|
||||
|
||||
readQueue <- string(buf.Bytes())
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||
wasClosed := r.closed.Swap(true)
|
||||
if !wasClosed {
|
||||
r.conn.Close(code, reason)
|
||||
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
|
||||
r.closeMutex.Lock()
|
||||
close(r.closedNotify)
|
||||
close(r.writeQueue)
|
||||
r.conn = nil
|
||||
r.closeMutex.Unlock()
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -17,7 +17,7 @@ func TestEOSEMadness(t *testing.T) {
|
||||
}, SubscriptionOptions{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeout := time.After(3 * time.Second)
|
||||
timeout := time.After(2 * time.Second)
|
||||
n := 0
|
||||
e := 0
|
||||
|
||||
|
||||
@@ -29,8 +29,6 @@ type Store interface {
|
||||
}
|
||||
```
|
||||
|
||||
[](https://pkg.go.dev/fiatjaf.com/nostr/eventstore) [](https://fiatjaf.com/nostr/eventstore/actions/workflows/test.yml)
|
||||
|
||||
## Available Implementations
|
||||
|
||||
- **bleve**: Full-text search and indexing using the Bleve search library
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBleveFlow(t *testing.T) {
|
||||
@@ -21,7 +22,9 @@ func TestBleveFlow(t *testing.T) {
|
||||
Path: "/tmp/blevetest-bleve",
|
||||
RawEventStore: bb,
|
||||
}
|
||||
bl.Init()
|
||||
err := bl.Init()
|
||||
require.NoError(t, err, "init")
|
||||
|
||||
defer bl.Close()
|
||||
|
||||
willDelete := make([]nostr.Event, 0, 3)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"fiatjaf.com/nostr"
|
||||
)
|
||||
|
||||
func (b *BleveBackend) DeleteEvent(id nostr.ID) error {
|
||||
return b.index.Delete(id.Hex())
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package bleve
|
||||
|
||||
const (
|
||||
idField = "i"
|
||||
contentField = "c"
|
||||
kindField = "k"
|
||||
createdAtField = "a"
|
||||
pubkeyField = "p"
|
||||
)
|
||||
@@ -0,0 +1,104 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// lexer tokenizes the input string
|
||||
type Lexer struct {
|
||||
input string
|
||||
pos int
|
||||
|
||||
peekedQueue []Token
|
||||
}
|
||||
|
||||
func NewLexer(input string) *Lexer {
|
||||
return &Lexer{input: input, pos: 0}
|
||||
}
|
||||
|
||||
func (l *Lexer) peek() rune {
|
||||
if l.pos >= len(l.input) {
|
||||
return 0
|
||||
}
|
||||
return rune(l.input[l.pos])
|
||||
}
|
||||
|
||||
func (l *Lexer) advance() rune {
|
||||
if l.pos >= len(l.input) {
|
||||
return 0
|
||||
}
|
||||
ch := rune(l.input[l.pos])
|
||||
l.pos++
|
||||
return ch
|
||||
}
|
||||
|
||||
func (l *Lexer) skipWhitespace() {
|
||||
for l.peek() != 0 && unicode.IsSpace(l.peek()) {
|
||||
l.advance()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lexer) readWord() string {
|
||||
start := l.pos
|
||||
|
||||
// read regular word (alphanumeric, hyphens, underscores)
|
||||
for l.peek() != 0 && !unicode.IsSpace(l.peek()) &&
|
||||
l.peek() != '(' && l.peek() != ')' && l.peek() != '"' {
|
||||
l.advance()
|
||||
}
|
||||
|
||||
return l.input[start:l.pos]
|
||||
}
|
||||
|
||||
func (l *Lexer) PeekToken() Token {
|
||||
next := l.NextToken()
|
||||
l.peekedQueue = append(l.peekedQueue, next)
|
||||
return next
|
||||
}
|
||||
|
||||
func (l *Lexer) ReturnToken(tok Token) {
|
||||
l.peekedQueue = append(l.peekedQueue, tok)
|
||||
}
|
||||
|
||||
func (l *Lexer) NextToken() (tok Token) {
|
||||
if len(l.peekedQueue) > 0 {
|
||||
next := l.peekedQueue[len(l.peekedQueue)-1]
|
||||
l.peekedQueue = l.peekedQueue[0 : len(l.peekedQueue)-1]
|
||||
return next
|
||||
}
|
||||
|
||||
l.skipWhitespace()
|
||||
|
||||
if l.pos >= len(l.input) {
|
||||
return Token{Type: TokenEOF}
|
||||
}
|
||||
|
||||
ch := l.peek()
|
||||
|
||||
switch ch {
|
||||
case '(':
|
||||
l.advance()
|
||||
return Token{Type: TokenLParen, Value: "("}
|
||||
case ')':
|
||||
l.advance()
|
||||
return Token{Type: TokenRParen, Value: ")"}
|
||||
case '"':
|
||||
l.advance()
|
||||
return Token{Type: TokenQuote, Value: "\""}
|
||||
default:
|
||||
word := l.readWord()
|
||||
upperWord := strings.ToUpper(word)
|
||||
|
||||
switch upperWord {
|
||||
case "OR", "||":
|
||||
return Token{Type: TokenOR, Value: word}
|
||||
case "AND", "&&":
|
||||
return Token{Type: TokenAND, Value: word}
|
||||
case "NOT", "!":
|
||||
return Token{Type: TokenNOT, Value: word}
|
||||
default:
|
||||
return Token{Type: TokenWord, Value: word}
|
||||
}
|
||||
}
|
||||
}
|
||||
+425
-15
@@ -1,34 +1,101 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
"fiatjaf.com/nostr/nip27"
|
||||
"fiatjaf.com/nostr/nip73"
|
||||
"fiatjaf.com/nostr/sdk"
|
||||
bleve "github.com/blevesearch/bleve/v2"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/analyzer/simple"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/gl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/in"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
|
||||
bleveMapping "github.com/blevesearch/bleve/v2/mapping"
|
||||
bleveQuery "github.com/blevesearch/bleve/v2/search/query"
|
||||
"github.com/pemistahl/lingua-go"
|
||||
)
|
||||
|
||||
var _ eventstore.Store = (*BleveBackend)(nil)
|
||||
const (
|
||||
labelContentField = "c"
|
||||
labelKindField = "k"
|
||||
labelCreatedAtField = "a"
|
||||
labelAuthorField = "p"
|
||||
labelReferencesField = "r"
|
||||
labelExtrasField = "x"
|
||||
)
|
||||
|
||||
var SupportedLanguages = []lingua.Language{
|
||||
// each of these translates to a specific bleve analyzer
|
||||
// except for japanese-korean-chinese that all use the same "cjk" analyzer
|
||||
lingua.Arabic,
|
||||
lingua.Chinese,
|
||||
lingua.Croatian,
|
||||
lingua.Danish,
|
||||
lingua.Dutch,
|
||||
lingua.English,
|
||||
lingua.Finnish,
|
||||
lingua.French,
|
||||
lingua.German,
|
||||
lingua.Hindi,
|
||||
lingua.Hungarian,
|
||||
lingua.Italian,
|
||||
lingua.Japanese,
|
||||
lingua.Korean,
|
||||
lingua.Persian,
|
||||
lingua.Polish,
|
||||
lingua.Portuguese,
|
||||
lingua.Romanian,
|
||||
lingua.Russian,
|
||||
lingua.Spanish,
|
||||
lingua.Swedish,
|
||||
lingua.Turkish,
|
||||
}
|
||||
|
||||
type BleveBackend struct {
|
||||
sync.Mutex
|
||||
// Path is where the index will be saved
|
||||
Path string
|
||||
|
||||
// RawEventStore is where we'll fetch the raw events from
|
||||
// bleve will only store ids, so the actual events must be somewhere else
|
||||
Path string
|
||||
RawEventStore eventstore.Store
|
||||
ReadOnly bool
|
||||
OpenTimeout time.Duration
|
||||
|
||||
index bleve.Index
|
||||
}
|
||||
IndexableKinds []nostr.Kind
|
||||
|
||||
func (b *BleveBackend) Close() {
|
||||
if b.index != nil {
|
||||
b.index.Close()
|
||||
}
|
||||
Languages []lingua.Language
|
||||
languageCodes []string
|
||||
|
||||
index bleve.Index
|
||||
detector lingua.LanguageDetector
|
||||
}
|
||||
|
||||
func (b *BleveBackend) Init() error {
|
||||
@@ -38,12 +105,94 @@ func (b *BleveBackend) Init() error {
|
||||
if b.RawEventStore == nil {
|
||||
return fmt.Errorf("missing RawEventStore")
|
||||
}
|
||||
if len(b.Languages) == 0 {
|
||||
return fmt.Errorf("missing Languages")
|
||||
}
|
||||
if len(b.IndexableKinds) == 0 {
|
||||
b.IndexableKinds = []nostr.Kind{0, 1, 6, 11, 16, 20, 21, 22, 24, 1111, 9802, 30023, 30818}
|
||||
}
|
||||
|
||||
// try to open existing index
|
||||
index, err := bleve.Open(b.Path)
|
||||
validLanguages := make([]lingua.Language, 0, len(b.Languages))
|
||||
b.languageCodes = make([]string, 0, len(b.Languages))
|
||||
for _, lang := range b.Languages {
|
||||
var code string
|
||||
|
||||
switch lang {
|
||||
case lingua.Chinese, lingua.Korean, lingua.Japanese:
|
||||
code = "cjk"
|
||||
default:
|
||||
code = strings.ToLower(lang.IsoCode639_1().String())
|
||||
}
|
||||
|
||||
if slices.Contains(b.languageCodes, code) {
|
||||
continue
|
||||
}
|
||||
|
||||
validLanguages = append(validLanguages, lang)
|
||||
b.languageCodes = append(b.languageCodes, code)
|
||||
}
|
||||
b.Languages = validLanguages
|
||||
|
||||
opts := map[string]any{
|
||||
"read_only": b.ReadOnly,
|
||||
}
|
||||
if b.OpenTimeout != 0 {
|
||||
opts["bolt_timeout"] = b.OpenTimeout.String()
|
||||
}
|
||||
|
||||
index, err := bleve.OpenUsing(b.Path, opts)
|
||||
if err == bleve.ErrorIndexPathDoesNotExist {
|
||||
// create new index with default mapping
|
||||
mapping := bleveMapping.NewIndexMapping()
|
||||
mapping.DefaultMapping.Dynamic = false
|
||||
doc := bleveMapping.NewDocumentStaticMapping()
|
||||
|
||||
for _, code := range b.languageCodes {
|
||||
contentField := bleveMapping.NewTextFieldMapping()
|
||||
contentField.Analyzer = code
|
||||
contentField.Store = false
|
||||
contentField.IncludeTermVectors = false
|
||||
contentField.DocValues = false
|
||||
contentField.IncludeInAll = false
|
||||
doc.AddFieldMappingsAt(labelContentField+"_"+code, contentField)
|
||||
}
|
||||
|
||||
extrasField := bleveMapping.NewTextFieldMapping()
|
||||
extrasField.Analyzer = "simple"
|
||||
extrasField.Store = false
|
||||
extrasField.IncludeTermVectors = false
|
||||
extrasField.DocValues = false
|
||||
extrasField.IncludeInAll = false
|
||||
doc.AddFieldMappingsAt(labelExtrasField, extrasField)
|
||||
|
||||
referencesField := bleveMapping.NewKeywordFieldMapping()
|
||||
referencesField.DocValues = false
|
||||
referencesField.Store = false
|
||||
referencesField.IncludeTermVectors = false
|
||||
referencesField.IncludeInAll = false
|
||||
doc.AddFieldMappingsAt(labelReferencesField, referencesField)
|
||||
|
||||
authorField := bleveMapping.NewKeywordFieldMapping()
|
||||
authorField.DocValues = false
|
||||
authorField.Store = false
|
||||
authorField.IncludeTermVectors = false
|
||||
doc.AddFieldMappingsAt(labelAuthorField, authorField)
|
||||
|
||||
kindField := bleveMapping.NewKeywordFieldMapping()
|
||||
kindField.DocValues = false
|
||||
kindField.Store = false
|
||||
kindField.IncludeTermVectors = false
|
||||
kindField.IncludeInAll = false
|
||||
doc.AddFieldMappingsAt(labelKindField, kindField)
|
||||
|
||||
timestampField := bleveMapping.NewDateTimeFieldMapping()
|
||||
timestampField.DocValues = false
|
||||
timestampField.Store = false
|
||||
timestampField.IncludeTermVectors = false
|
||||
timestampField.IncludeInAll = false
|
||||
doc.AddFieldMappingsAt(labelCreatedAtField, timestampField)
|
||||
|
||||
mapping.AddDocumentMapping("_default", doc)
|
||||
|
||||
index, err = bleve.New(b.Path, mapping)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating index: %w", err)
|
||||
@@ -53,6 +202,116 @@ func (b *BleveBackend) Init() error {
|
||||
}
|
||||
|
||||
b.index = index
|
||||
b.detector = lingua.NewLanguageDetectorBuilder().
|
||||
FromLanguages(b.Languages...).
|
||||
Build()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BleveBackend) Close() {
|
||||
if b != nil && b.index != nil {
|
||||
b.index.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BleveBackend) SaveEvent(event nostr.Event) error {
|
||||
if slices.Contains(b.IndexableKinds, event.Kind) {
|
||||
return b.indexEvent(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BleveBackend) DeleteEvent(id nostr.ID) error {
|
||||
if b != nil && b.index != nil {
|
||||
return b.index.Delete(id.Hex())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BleveBackend) indexEvent(evt nostr.Event) error {
|
||||
docID := evt.ID
|
||||
|
||||
var references []string
|
||||
var extras string
|
||||
|
||||
switch evt.Kind {
|
||||
case 6, 16:
|
||||
var innerEvt nostr.Event
|
||||
if err := json.Unmarshal([]byte(evt.Content), &innerEvt); err != nil || !innerEvt.VerifySignature() {
|
||||
return nil
|
||||
}
|
||||
evt = innerEvt
|
||||
case 0:
|
||||
var pm sdk.ProfileMetadata
|
||||
if err := json.Unmarshal([]byte(evt.Content), &pm); err == nil {
|
||||
evt.Content = pm.Name + "\n" + pm.DisplayName + "\n" + pm.About
|
||||
references = append(references, pm.NIP05)
|
||||
}
|
||||
case 9802:
|
||||
for _, tag := range evt.Tags {
|
||||
if len(tag) < 2 {
|
||||
continue
|
||||
}
|
||||
switch tag[0] {
|
||||
case "comment":
|
||||
evt.Content += "\n\n" + tag[1]
|
||||
case "e":
|
||||
if ptr, err := nostr.EventPointerFromTag(tag); err == nil {
|
||||
references = append(references, ptr.AsTagReference())
|
||||
}
|
||||
case "a":
|
||||
if ptr, err := nostr.EntityPointerFromTag(tag); err == nil {
|
||||
references = append(references, ptr.AsTagReference())
|
||||
}
|
||||
case "r":
|
||||
references = append(references, tag[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doc := map[string]any{
|
||||
labelKindField: strconv.Itoa(int(evt.Kind)),
|
||||
labelAuthorField: evt.PubKey.Hex()[56:],
|
||||
labelCreatedAtField: evt.CreatedAt.Time(),
|
||||
}
|
||||
|
||||
content := strings.Builder{}
|
||||
content.Grow(len(evt.Content))
|
||||
|
||||
for block := range nip27.Parse(evt.Content) {
|
||||
if block.Pointer == nil {
|
||||
content.WriteString(strings.TrimSpace(block.Text))
|
||||
} else {
|
||||
references = append(references, block.Pointer.AsTagReference())
|
||||
if ep, ok := block.Pointer.(nip73.ExternalPointer); ok {
|
||||
extras += ep.Thing + " "
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
indexableContent := content.String()
|
||||
lang, ok := b.detector.DetectLanguageOf(indexableContent)
|
||||
if !ok {
|
||||
lang = lingua.English
|
||||
}
|
||||
|
||||
var analyzerLangCode string
|
||||
switch lang {
|
||||
case lingua.Japanese, lingua.Chinese, lingua.Korean:
|
||||
analyzerLangCode = "cjk"
|
||||
default:
|
||||
analyzerLangCode = strings.ToLower(lang.IsoCode639_1().String())
|
||||
}
|
||||
doc[labelContentField+"_"+analyzerLangCode] = indexableContent
|
||||
|
||||
doc[labelReferencesField] = references
|
||||
doc[labelExtrasField] = extras
|
||||
|
||||
if err := b.index.Index(docID.Hex(), doc); err != nil {
|
||||
return fmt.Errorf("failed to index '%s' document: %w", docID.Hex(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,3 +323,154 @@ func (b *BleveBackend) CountEvents(filter nostr.Filter) (uint32, error) {
|
||||
|
||||
return 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (b *BleveBackend) QueryEvents(filter nostr.Filter, maxLimit int) iter.Seq[nostr.Event] {
|
||||
return func(yield func(nostr.Event) bool) {
|
||||
if tlimit := filter.GetTheoreticalLimit(); tlimit == 0 {
|
||||
return
|
||||
} else if tlimit < maxLimit {
|
||||
maxLimit = tlimit
|
||||
}
|
||||
|
||||
filter.Search = strings.TrimSpace(filter.Search)
|
||||
if len(filter.Search) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
and := make([]bleveQuery.Query, 0, 3)
|
||||
|
||||
searchC := strings.Builder{}
|
||||
searchC.Grow(len(filter.Search))
|
||||
|
||||
for block := range nip27.Parse(filter.Search) {
|
||||
if block.Pointer != nil {
|
||||
genericRef := bleve.NewTermQuery(block.Pointer.AsTagReference())
|
||||
genericRef.SetField(labelReferencesField)
|
||||
genericRef.SetBoost(2)
|
||||
|
||||
var ref bleveQuery.Query = genericRef
|
||||
if profile, ok := block.Pointer.(nostr.ProfilePointer); ok {
|
||||
authorQuery := bleve.NewTermQuery(profile.PublicKey.Hex()[56:])
|
||||
authorQuery.SetField(labelAuthorField)
|
||||
authorQuery.SetBoost(2)
|
||||
orRef := bleve.NewDisjunctionQuery()
|
||||
orRef.AddQuery(genericRef)
|
||||
orRef.AddQuery(authorQuery)
|
||||
ref = orRef
|
||||
} else if addr, ok := block.Pointer.(nostr.EntityPointer); ok {
|
||||
authorQuery := bleve.NewTermQuery(addr.PublicKey.Hex()[56:])
|
||||
authorQuery.SetField(labelAuthorField)
|
||||
authorQuery.SetBoost(2)
|
||||
orRef := bleve.NewDisjunctionQuery()
|
||||
orRef.AddQuery(genericRef)
|
||||
orRef.AddQuery(authorQuery)
|
||||
ref = orRef
|
||||
}
|
||||
and = append(and, ref)
|
||||
} else {
|
||||
searchC.WriteString(strings.TrimSpace(block.Text))
|
||||
}
|
||||
}
|
||||
|
||||
searchContent := searchC.String()
|
||||
|
||||
var exactMatches []string
|
||||
if len(searchContent) > 0 {
|
||||
contentQueries := make([]bleveQuery.Query, 0, len(b.Languages)+1)
|
||||
|
||||
searchQ, exactMatches_, err := parse(searchContent, labelContentField+"_"+b.languageCodes[0])
|
||||
if err != nil {
|
||||
for _, code := range b.languageCodes {
|
||||
match := bleve.NewMatchQuery(searchContent)
|
||||
match.SetField(labelContentField + "_" + code)
|
||||
contentQueries = append(contentQueries, match)
|
||||
}
|
||||
} else {
|
||||
contentQueries = append(contentQueries, searchQ)
|
||||
for _, code := range b.languageCodes[1:] {
|
||||
searchQ, _, _ := parse(searchContent, labelContentField+"_"+code)
|
||||
contentQueries = append(contentQueries, searchQ)
|
||||
}
|
||||
}
|
||||
exactMatches = exactMatches_
|
||||
|
||||
extrasQ := bleve.NewMatchQuery(searchContent)
|
||||
extrasQ.SetField(labelExtrasField)
|
||||
contentQueries = append(contentQueries, extrasQ)
|
||||
|
||||
and = append(and, bleveQuery.NewDisjunctionQuery(contentQueries))
|
||||
}
|
||||
|
||||
if len(filter.Kinds) > 0 {
|
||||
eitherKind := bleve.NewDisjunctionQuery()
|
||||
for _, kind := range filter.Kinds {
|
||||
kindQ := bleve.NewTermQuery(strconv.Itoa(int(kind)))
|
||||
kindQ.SetField(labelKindField)
|
||||
eitherKind.AddQuery(kindQ)
|
||||
}
|
||||
and = append(and, eitherKind)
|
||||
}
|
||||
|
||||
if len(filter.Authors) > 0 {
|
||||
eitherPubkey := bleve.NewDisjunctionQuery()
|
||||
for _, pubkey := range filter.Authors {
|
||||
pubkeyQ := bleve.NewTermQuery(pubkey.Hex()[56:])
|
||||
pubkeyQ.SetField(labelAuthorField)
|
||||
eitherPubkey.AddQuery(pubkeyQ)
|
||||
}
|
||||
and = append(and, eitherPubkey)
|
||||
}
|
||||
|
||||
if filter.Since != 0 || filter.Until != 0 {
|
||||
var min time.Time
|
||||
if filter.Since != 0 {
|
||||
min = filter.Since.Time()
|
||||
}
|
||||
var max time.Time
|
||||
if filter.Until != 0 {
|
||||
max = filter.Until.Time()
|
||||
} else {
|
||||
max = time.Now()
|
||||
}
|
||||
dateRangeQ := bleve.NewDateRangeQuery(min, max)
|
||||
dateRangeQ.SetField(labelCreatedAtField)
|
||||
and = append(and, dateRangeQ)
|
||||
}
|
||||
|
||||
q := bleveQuery.NewConjunctionQuery(and)
|
||||
req := bleve.NewSearchRequest(q)
|
||||
req.Size = maxLimit
|
||||
req.From = 0
|
||||
req.Explain = true
|
||||
|
||||
result, err := b.index.Search(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resultHit:
|
||||
for _, hit := range result.Hits {
|
||||
id, err := nostr.IDFromHex(hit.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for evt := range b.RawEventStore.QueryEvents(nostr.Filter{IDs: []nostr.ID{id}}, 1) {
|
||||
for _, exactMatch := range exactMatches {
|
||||
if !strings.Contains(strings.ToLower(evt.Content), exactMatch) {
|
||||
continue resultHit
|
||||
}
|
||||
}
|
||||
|
||||
for f, v := range filter.Tags {
|
||||
if !evt.Tags.ContainsAny(f, v) {
|
||||
continue resultHit
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(evt) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"strconv"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
bleve "github.com/blevesearch/bleve/v2"
|
||||
"github.com/blevesearch/bleve/v2/search/query"
|
||||
)
|
||||
|
||||
func (b *BleveBackend) QueryEvents(filter nostr.Filter, maxLimit int) iter.Seq[nostr.Event] {
|
||||
return func(yield func(nostr.Event) bool) {
|
||||
if tlimit := filter.GetTheoreticalLimit(); tlimit == 0 {
|
||||
return
|
||||
} else if tlimit < maxLimit {
|
||||
maxLimit = tlimit
|
||||
}
|
||||
|
||||
if len(filter.Search) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
searchQ := bleve.NewMatchQuery(filter.Search)
|
||||
searchQ.SetField(contentField)
|
||||
var q query.Query = searchQ
|
||||
|
||||
conjQueries := []query.Query{searchQ}
|
||||
|
||||
if len(filter.Kinds) > 0 {
|
||||
eitherKind := bleve.NewDisjunctionQuery()
|
||||
for _, kind := range filter.Kinds {
|
||||
kindQ := bleve.NewTermQuery(strconv.Itoa(int(kind)))
|
||||
kindQ.SetField(kindField)
|
||||
eitherKind.AddQuery(kindQ)
|
||||
}
|
||||
conjQueries = append(conjQueries, eitherKind)
|
||||
}
|
||||
|
||||
if len(filter.Authors) > 0 {
|
||||
eitherPubkey := bleve.NewDisjunctionQuery()
|
||||
for _, pubkey := range filter.Authors {
|
||||
if len(pubkey) != 64 {
|
||||
continue
|
||||
}
|
||||
pubkeyQ := bleve.NewTermQuery(pubkey.Hex()[56:])
|
||||
pubkeyQ.SetField(pubkeyField)
|
||||
eitherPubkey.AddQuery(pubkeyQ)
|
||||
}
|
||||
conjQueries = append(conjQueries, eitherPubkey)
|
||||
}
|
||||
|
||||
if filter.Since != 0 || filter.Until != 0 {
|
||||
var min *float64
|
||||
if filter.Since != 0 {
|
||||
minVal := float64(filter.Since)
|
||||
min = &minVal
|
||||
}
|
||||
var max *float64
|
||||
if filter.Until != 0 {
|
||||
maxVal := float64(filter.Until)
|
||||
max = &maxVal
|
||||
}
|
||||
dateRangeQ := bleve.NewNumericRangeInclusiveQuery(min, max, nil, nil)
|
||||
dateRangeQ.SetField(createdAtField)
|
||||
conjQueries = append(conjQueries, dateRangeQ)
|
||||
}
|
||||
|
||||
if len(conjQueries) > 1 {
|
||||
q = bleve.NewConjunctionQuery(conjQueries...)
|
||||
}
|
||||
|
||||
req := bleve.NewSearchRequest(q)
|
||||
req.Size = maxLimit
|
||||
req.From = 0
|
||||
|
||||
result, err := b.index.Search(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, hit := range result.Hits {
|
||||
id, err := nostr.IDFromHex(hit.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for evt := range b.RawEventStore.QueryEvents(nostr.Filter{IDs: []nostr.ID{id}}, 1) {
|
||||
if !yield(evt) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
bleve "github.com/blevesearch/bleve/v2"
|
||||
bleveQuery "github.com/blevesearch/bleve/v2/search/query"
|
||||
)
|
||||
|
||||
// token types
|
||||
type TokenType int
|
||||
|
||||
const (
|
||||
TokenWord TokenType = iota
|
||||
TokenOR
|
||||
TokenAND
|
||||
TokenNOT
|
||||
TokenLParen
|
||||
TokenRParen
|
||||
TokenQuote
|
||||
TokenEOF
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Type TokenType
|
||||
Value string
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
lexer *Lexer
|
||||
field string
|
||||
}
|
||||
|
||||
func parse(input string, field string) (bleveQuery.Query, []string, error) {
|
||||
lexer := NewLexer(input)
|
||||
p := &Parser{
|
||||
lexer: lexer,
|
||||
}
|
||||
|
||||
var exactMatches []string
|
||||
var reusableCurrentMatch strings.Builder
|
||||
var currentExactMatch *strings.Builder
|
||||
var currentWords []string
|
||||
var negated bool
|
||||
var parents []bleveQuery.Query
|
||||
var parentOps []TokenType // tracks if parent should be AND or OR
|
||||
var lastOp TokenType = TokenAND // track last operator for parentheses
|
||||
|
||||
curr := bleve.NewBooleanQuery()
|
||||
|
||||
for {
|
||||
token := p.lexer.NextToken()
|
||||
|
||||
if token.Type == TokenEOF {
|
||||
if len(currentWords) > 0 {
|
||||
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
match.SetField(field)
|
||||
if negated {
|
||||
curr.AddMustNot(match)
|
||||
} else {
|
||||
curr.AddMust(match)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if token.Type == TokenQuote {
|
||||
if currentExactMatch == nil {
|
||||
currentExactMatch = &reusableCurrentMatch
|
||||
} else {
|
||||
exactMatches = append(exactMatches, currentExactMatch.String())
|
||||
currentExactMatch.Reset()
|
||||
reusableCurrentMatch = *currentExactMatch
|
||||
currentExactMatch = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentExactMatch != nil {
|
||||
if currentExactMatch.Len() > 0 {
|
||||
currentExactMatch.WriteByte(' ')
|
||||
}
|
||||
currentExactMatch.WriteString(strings.ToLower(token.Value))
|
||||
currentWords = append(currentWords, token.Value)
|
||||
continue
|
||||
}
|
||||
|
||||
if token.Type == TokenWord {
|
||||
currentWords = append(currentWords, token.Value)
|
||||
continue
|
||||
} else if len(currentWords) > 0 {
|
||||
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
match.SetField(field)
|
||||
if negated {
|
||||
curr.AddMustNot(match)
|
||||
} else {
|
||||
curr.AddMust(match)
|
||||
}
|
||||
currentWords = currentWords[:0]
|
||||
negated = false
|
||||
}
|
||||
|
||||
switch token.Type {
|
||||
case TokenLParen:
|
||||
// push current query to parents stack with the last operator
|
||||
parents = append(parents, curr)
|
||||
parentOps = append(parentOps, lastOp)
|
||||
// reset lastOp to default for inner parentheses
|
||||
lastOp = TokenAND
|
||||
// start new boolean query for parentheses content
|
||||
curr = bleve.NewBooleanQuery()
|
||||
continue
|
||||
case TokenRParen:
|
||||
// finalize any remaining words
|
||||
if len(currentWords) > 0 {
|
||||
match := bleve.NewMatchQuery(strings.Join(currentWords, " "))
|
||||
match.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
match.SetField(field)
|
||||
if negated {
|
||||
curr.AddMustNot(match)
|
||||
} else {
|
||||
curr.AddMust(match)
|
||||
}
|
||||
currentWords = currentWords[:0]
|
||||
negated = false
|
||||
}
|
||||
|
||||
// pop parent and merge with current
|
||||
if len(parents) > 0 {
|
||||
parent := parents[len(parents)-1]
|
||||
op := parentOps[len(parentOps)-1]
|
||||
|
||||
// create a new boolean query to combine parent and current
|
||||
var combined bleveQuery.Query
|
||||
switch op {
|
||||
case TokenOR:
|
||||
or := bleve.NewDisjunctionQuery()
|
||||
or.AddQuery(parent)
|
||||
or.AddQuery(curr)
|
||||
combined = or
|
||||
case TokenAND:
|
||||
and := bleve.NewConjunctionQuery()
|
||||
and.AddQuery(parent)
|
||||
and.AddQuery(curr)
|
||||
combined = and
|
||||
}
|
||||
|
||||
curr = bleve.NewBooleanQuery()
|
||||
curr.AddMust(combined)
|
||||
parents = parents[:len(parents)-1]
|
||||
parentOps = parentOps[:len(parentOps)-1]
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
next := p.lexer.NextToken()
|
||||
following := p.lexer.PeekToken()
|
||||
if next.Type == TokenNOT {
|
||||
negated = true
|
||||
}
|
||||
|
||||
switch token.Type {
|
||||
case TokenOR:
|
||||
if next.Type != TokenLParen && !(next.Type == TokenNOT && following.Type == TokenLParen) {
|
||||
// if this is not followed by a "(" or "NOT (" consider the follow next word as the only parameter
|
||||
other := bleve.NewMatchQuery(next.Value)
|
||||
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
other.SetField(field)
|
||||
or := bleve.NewDisjunctionQuery()
|
||||
or.AddQuery(curr)
|
||||
or.AddQuery(other)
|
||||
curr = bleve.NewBooleanQuery()
|
||||
curr.AddMust(or)
|
||||
} else {
|
||||
lastOp = TokenOR
|
||||
}
|
||||
case TokenAND:
|
||||
if next.Type != TokenLParen && !(next.Type == TokenNOT && following.Type == TokenLParen) {
|
||||
// if this is not followed by a "(" consider the follow next word as the only parameter
|
||||
other := bleve.NewMatchQuery(next.Value)
|
||||
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
other.SetField(field)
|
||||
and := bleve.NewConjunctionQuery()
|
||||
and.AddQuery(curr)
|
||||
and.AddQuery(other)
|
||||
curr = bleve.NewBooleanQuery()
|
||||
curr.AddMust(and)
|
||||
} else {
|
||||
lastOp = TokenAND
|
||||
}
|
||||
case TokenNOT:
|
||||
if next.Type != TokenLParen {
|
||||
// if this is not followed by a "(" or "NOT (" consider the follow next word as the only parameter
|
||||
other := bleve.NewMatchQuery(next.Value)
|
||||
other.SetOperator(bleveQuery.MatchQueryOperatorAnd)
|
||||
other.SetField(field)
|
||||
curr.AddMustNot(other)
|
||||
} else {
|
||||
negated = true
|
||||
}
|
||||
default:
|
||||
p.lexer.ReturnToken(next)
|
||||
}
|
||||
}
|
||||
|
||||
return curr, exactMatches, nil
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/blevesearch/bleve/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseQuery(t *testing.T) {
|
||||
mapping := bleve.NewIndexMapping()
|
||||
mapping.DefaultAnalyzer = "en"
|
||||
index, err := bleve.NewMemOnly(mapping)
|
||||
require.NoError(t, err)
|
||||
|
||||
docs := []map[string]interface{}{
|
||||
{"id": "1", "phrase": "I like fruit especially banana and strawberry"},
|
||||
{"id": "2", "phrase": "I like fruit like apples and oranges"},
|
||||
{"id": "3", "phrase": "I like vegetables but not fruit"},
|
||||
{"id": "4", "phrase": "Banana bread is delicious"},
|
||||
{"id": "5", "phrase": "Strawberry jam and banana smoothie"},
|
||||
}
|
||||
|
||||
for _, doc := range docs {
|
||||
err := index.Index(doc["id"].(string), doc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
testQueries := []struct {
|
||||
query string
|
||||
expected int
|
||||
exactMatches []string
|
||||
}{
|
||||
{"fruit", 3, nil},
|
||||
{"banana (NOT delicious)", 2, nil},
|
||||
{"banana (NOT delicious) bread", 0, nil},
|
||||
{"smoothie OR apples", 2, nil},
|
||||
{"smoothie OR apples (NOT fruit)", 1, nil},
|
||||
{"\"I like\"", 3, []string{"i like"}},
|
||||
{"banana \"I like fruit\" strawberries", 1, []string{"i like fruit"}},
|
||||
{"\"I like fruit\" (strawberry OR apple)", 2, []string{"i like fruit"}},
|
||||
}
|
||||
|
||||
for _, test := range testQueries {
|
||||
query, exactMatches, err := parse(test.query, "phrase")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, test.exactMatches, exactMatches)
|
||||
|
||||
search := bleve.NewSearchRequest(query)
|
||||
results, err := index.Search(search)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, test.expected, int(results.Total),
|
||||
"query '%s' expected %d results, got %d", test.query, test.expected, results.Total)
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
)
|
||||
|
||||
func (b *BleveBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
filter := nostr.Filter{Kinds: []nostr.Kind{evt.Kind}, Authors: []nostr.PubKey{evt.PubKey}}
|
||||
if evt.Kind.IsAddressable() {
|
||||
filter.Tags = nostr.TagMap{"d": []string{evt.Tags.GetD()}}
|
||||
}
|
||||
|
||||
shouldStore := true
|
||||
for previous := range b.QueryEvents(filter, 1) {
|
||||
if nostr.IsOlder(previous, evt) {
|
||||
if err := b.DeleteEvent(previous.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete event for replacing: %w", err)
|
||||
}
|
||||
} else {
|
||||
shouldStore = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldStore {
|
||||
if err := b.SaveEvent(evt); err != nil && err != eventstore.ErrDupEvent {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package bleve
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
)
|
||||
|
||||
func (b *BleveBackend) SaveEvent(evt nostr.Event) error {
|
||||
doc := map[string]interface{}{
|
||||
contentField: evt.Content,
|
||||
kindField: strconv.Itoa(int(evt.Kind)),
|
||||
pubkeyField: evt.PubKey.Hex()[56:],
|
||||
createdAtField: float64(evt.CreatedAt),
|
||||
}
|
||||
|
||||
if err := b.index.Index(evt.ID.Hex(), doc); err != nil {
|
||||
return fmt.Errorf("failed to index '%s' document: %w", evt.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -28,6 +28,8 @@ type BoltBackend struct {
|
||||
MapSize int64
|
||||
DB *bbolt.DB
|
||||
|
||||
ReadOnly bool
|
||||
|
||||
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
||||
}
|
||||
|
||||
@@ -36,6 +38,7 @@ func (b *BoltBackend) Init() error {
|
||||
Timeout: 2 * time.Second,
|
||||
PreLoadFreelist: true,
|
||||
FreelistType: bbolt.FreelistMapType,
|
||||
ReadOnly: b.ReadOnly,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
return b.DB.Update(func(txn *bbolt.Tx) error {
|
||||
func (b *BoltBackend) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||
err = b.DB.Update(func(txn *bbolt.Tx) error {
|
||||
rawBucket := txn.Bucket(rawEventStore)
|
||||
|
||||
// check if we already have this id
|
||||
@@ -25,12 +25,12 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
}
|
||||
|
||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||
var err error
|
||||
var qerr error
|
||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
||||
err = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||
qerr = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
||||
if qerr != nil {
|
||||
return fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||
}
|
||||
|
||||
shouldStore := true
|
||||
@@ -39,6 +39,7 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
if err := b.delete(txn, previous.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
||||
}
|
||||
deleted = append(deleted, previous)
|
||||
} else {
|
||||
// there is a newer event already stored, so we won't store this
|
||||
shouldStore = false
|
||||
@@ -50,4 +51,5 @@ func (b *BoltBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
|
||||
return nil
|
||||
})
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
@@ -40,12 +40,12 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
||||
buf[0] = 0
|
||||
|
||||
if evt.Kind > MaxKind {
|
||||
return fmt.Errorf("kind is too big: %d, max is %d", evt.Kind, MaxKind)
|
||||
return fmt.Errorf("kind is too big: %d, max is %d", evt.Kind, uint16(MaxKind))
|
||||
}
|
||||
binary.LittleEndian.PutUint16(buf[1:3], uint16(evt.Kind))
|
||||
|
||||
if evt.CreatedAt > MaxCreatedAt {
|
||||
return fmt.Errorf("created_at is too big: %d, max is %d", evt.CreatedAt, MaxCreatedAt)
|
||||
return fmt.Errorf("created_at is too big: %d, max is %d", evt.CreatedAt, uint32(MaxCreatedAt))
|
||||
}
|
||||
binary.LittleEndian.PutUint32(buf[3:7], uint32(evt.CreatedAt))
|
||||
|
||||
@@ -58,7 +58,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
||||
|
||||
ntags := len(evt.Tags)
|
||||
if ntags > MaxTagCount {
|
||||
return fmt.Errorf("can't encode too many tags: %d, max is %d", ntags, MaxTagCount)
|
||||
return fmt.Errorf("can't encode too many tags: %d, max is %d", ntags, uint16(MaxTagCount))
|
||||
}
|
||||
binary.LittleEndian.PutUint16(buf[137:139], uint16(ntags))
|
||||
|
||||
@@ -68,7 +68,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
||||
|
||||
itemCount := len(tag)
|
||||
if itemCount > MaxTagItemCount {
|
||||
return fmt.Errorf("can't encode a tag with so many items: %d, max is %d", itemCount, MaxTagItemCount)
|
||||
return fmt.Errorf("can't encode a tag with so many items: %d, max is %d", itemCount, uint8(MaxTagItemCount))
|
||||
}
|
||||
buf[tagBase+tagOffset] = uint8(itemCount)
|
||||
|
||||
@@ -76,7 +76,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
||||
for _, item := range tag {
|
||||
itemSize := len(item)
|
||||
if itemSize > MaxTagItemSize {
|
||||
return fmt.Errorf("tag item is too large: %d, max is %d", itemSize, MaxTagItemSize)
|
||||
return fmt.Errorf("tag item is too large: %d, max is %d", itemSize, uint16(MaxTagItemSize))
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint16(buf[tagBase+tagOffset+itemOffset:], uint16(itemSize))
|
||||
@@ -91,7 +91,7 @@ func Marshal(evt nostr.Event, buf []byte) error {
|
||||
|
||||
// content
|
||||
if contentLength := len(evt.Content); contentLength > MaxContentSize {
|
||||
return fmt.Errorf("content is too large: %d, max is %d", contentLength, MaxContentSize)
|
||||
return fmt.Errorf("content is too large: %d, max is %d", contentLength, uint16(MaxContentSize))
|
||||
} else {
|
||||
binary.LittleEndian.PutUint16(buf[tagBase+tagsSectionLength:], uint16(contentLength))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package checks
|
||||
|
||||
import (
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
"fiatjaf.com/nostr/eventstore/bleve"
|
||||
"fiatjaf.com/nostr/eventstore/boltdb"
|
||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
||||
"fiatjaf.com/nostr/eventstore/mmm"
|
||||
@@ -13,5 +12,4 @@ var (
|
||||
_ eventstore.Store = (*lmdb.LMDBBackend)(nil)
|
||||
_ eventstore.Store = (*mmm.IndexingLayer)(nil)
|
||||
_ eventstore.Store = (*boltdb.BoltBackend)(nil)
|
||||
_ eventstore.Store = (*bleve.BleveBackend)(nil)
|
||||
)
|
||||
|
||||
@@ -36,18 +36,15 @@ func (b *LMDBBackend) CountEvents(filter nostr.Filter) (uint32, error) {
|
||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||
if it.exhausted ||
|
||||
it.err != nil ||
|
||||
len(it.key) != q.keySize ||
|
||||
len(it.key) != len(q.prefix)+4 ||
|
||||
!bytes.HasPrefix(it.key, q.prefix) {
|
||||
// either iteration has errored or we reached the end of this prefix
|
||||
break // stop this cursor and move to the next one
|
||||
}
|
||||
|
||||
// "id" indexes don't contain a timestamp
|
||||
if q.dbi != b.indexId {
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
|
||||
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
||||
@@ -129,18 +126,15 @@ func (b *LMDBBackend) CountEventsHLL(filter nostr.Filter, offset int) (uint32, *
|
||||
for {
|
||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||
if it.err != nil ||
|
||||
len(it.key) != q.keySize ||
|
||||
len(it.key) != len(q.prefix)+4 ||
|
||||
!bytes.HasPrefix(it.key, q.prefix) {
|
||||
// either iteration has errored or we reached the end of this prefix
|
||||
break // stop this cursor and move to the next one
|
||||
}
|
||||
|
||||
// "id" indexes don't contain a timestamp
|
||||
if q.dbi != b.indexId {
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
|
||||
// fetch actual event (we need it regardless because we need the pubkey for the hll)
|
||||
|
||||
@@ -45,7 +45,7 @@ func (it *iterator) pull(n int, since uint32) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(it.key) != query.keySize || !bytes.HasPrefix(it.key, query.prefix) {
|
||||
if len(it.key) != len(query.prefix)+4 || !bytes.HasPrefix(it.key, query.prefix) {
|
||||
// we reached the end of this prefix
|
||||
it.exhausted = true
|
||||
return
|
||||
|
||||
+1
-27
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
@@ -34,8 +33,6 @@ type LMDBBackend struct {
|
||||
|
||||
hllCache lmdb.DBI
|
||||
EnableHLLCacheFor func(kind nostr.Kind) (useCache bool, skipSavingActualEvent bool)
|
||||
|
||||
lastId atomic.Uint32
|
||||
}
|
||||
|
||||
func (b *LMDBBackend) Init() error {
|
||||
@@ -112,7 +109,7 @@ func (b *LMDBBackend) initialize() error {
|
||||
env.SetMapSize(b.MapSize)
|
||||
}
|
||||
|
||||
if err := env.Open(b.Path, lmdb.NoTLS|lmdb.WriteMap|b.extraFlags, 0644); err != nil {
|
||||
if err := env.Open(b.Path, lmdb.NoTLS|b.extraFlags, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
b.lmdbEnv = env
|
||||
@@ -186,28 +183,5 @@ func (b *LMDBBackend) initialize() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// get lastId
|
||||
if err := b.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
||||
txn.RawRead = true
|
||||
cursor, err := txn.OpenCursor(b.rawEventStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cursor.Close()
|
||||
k, _, err := cursor.Get(nil, nil, lmdb.Last)
|
||||
if lmdb.IsNotFound(err) {
|
||||
// nothing found, so we're at zero
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.lastId.Store(binary.BigEndian.Uint32(k))
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.migrate()
|
||||
}
|
||||
|
||||
@@ -54,7 +54,6 @@ func (b *LMDBBackend) queryByIds(txn *lmdb.Txn, ids []nostr.ID, yield func(nostr
|
||||
continue
|
||||
}
|
||||
|
||||
txn.Get(b.rawEventStore, idx)
|
||||
bin, err := txn.Get(b.rawEventStore, idx)
|
||||
if err != nil {
|
||||
continue
|
||||
|
||||
@@ -14,7 +14,6 @@ type query struct {
|
||||
i int
|
||||
dbi lmdb.DBI
|
||||
prefix []byte
|
||||
keySize int
|
||||
startingPoint []byte
|
||||
}
|
||||
|
||||
@@ -40,10 +39,10 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
||||
}
|
||||
}
|
||||
for i, q := range queries {
|
||||
sp := make([]byte, len(q.prefix))
|
||||
sp = sp[0:len(q.prefix)]
|
||||
copy(sp, q.prefix)
|
||||
queries[i].startingPoint = binary.BigEndian.AppendUint32(sp, uint32(until))
|
||||
sp := make([]byte, len(q.prefix)+4)
|
||||
copy(sp[0:len(q.prefix)], q.prefix)
|
||||
binary.BigEndian.PutUint32(sp[len(q.prefix):], uint32(until))
|
||||
queries[i].startingPoint = sp
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -64,39 +63,27 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
||||
}
|
||||
|
||||
// only "p" tag has a goodness of 2, so
|
||||
if goodness == 2 {
|
||||
if goodness == 2 && filter.Kinds != nil {
|
||||
// this means we got a "p" tag, so we will use the ptag-kind index
|
||||
i := 0
|
||||
if filter.Kinds != nil {
|
||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||
for _, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
|
||||
for _, kind := range filter.Kinds {
|
||||
k := make([]byte, 8+2)
|
||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: b.indexPTagKind, prefix: k[0 : 8+2], keySize: 8 + 2 + 4}
|
||||
i++
|
||||
}
|
||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||
for _, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
} else {
|
||||
// even if there are no kinds, in that case we will just return any kind and not care
|
||||
queries = make([]query, len(tagValues))
|
||||
for i, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
|
||||
k := make([]byte, 8)
|
||||
for _, kind := range filter.Kinds {
|
||||
k := make([]byte, 8+2)
|
||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
queries[i] = query{i: i, dbi: b.indexPTagKind, prefix: k[0:8], keySize: 8 + 2 + 4}
|
||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: b.indexPTagKind,
|
||||
prefix: k[0 : 8+2],
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -107,7 +94,11 @@ func (b *LMDBBackend) prepareQueries(filter nostr.Filter) (
|
||||
dbi, k, offset := b.getTagIndexPrefix(tagKey, value)
|
||||
// remove the last parts part to get just the prefix we want here
|
||||
prefix := k[0:offset]
|
||||
queries[i] = query{i: i, dbi: dbi, prefix: prefix, keySize: len(prefix) + 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: dbi,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
||||
@@ -142,7 +133,11 @@ pubkeyMatching:
|
||||
// will use pubkey index
|
||||
queries = make([]query, len(filter.Authors))
|
||||
for i, pk := range filter.Authors {
|
||||
queries[i] = query{i: i, dbi: b.indexPubkey, prefix: pk[0:8], keySize: 8 + 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: b.indexPubkey,
|
||||
prefix: pk[0:8],
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// will use pubkeyKind index
|
||||
@@ -153,7 +148,11 @@ pubkeyMatching:
|
||||
prefix := make([]byte, 8+2)
|
||||
copy(prefix[0:8], pk[0:8])
|
||||
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: b.indexPubkeyKind, prefix: prefix[0 : 8+2], keySize: 10 + 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: b.indexPubkeyKind,
|
||||
prefix: prefix[0 : 8+2],
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
@@ -170,7 +169,11 @@ pubkeyMatching:
|
||||
for i, kind := range filter.Kinds {
|
||||
prefix := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: b.indexKind, prefix: prefix[0:2], keySize: 2 + 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: b.indexKind,
|
||||
prefix: prefix[0:2],
|
||||
}
|
||||
}
|
||||
|
||||
// potentially with an extra useless tag filtering
|
||||
@@ -181,6 +184,10 @@ pubkeyMatching:
|
||||
// if we got here our query will have nothing to filter with
|
||||
queries = make([]query, 1)
|
||||
prefix := make([]byte, 0)
|
||||
queries[0] = query{i: 0, dbi: b.indexCreatedAt, prefix: prefix, keySize: 0 + 4}
|
||||
queries[0] = query{
|
||||
i: 0,
|
||||
dbi: b.indexCreatedAt,
|
||||
prefix: prefix,
|
||||
}
|
||||
return queries, nil, nil, "", nil, since, nil
|
||||
}
|
||||
|
||||
+12
-14
@@ -2,14 +2,13 @@ package lmdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||
)
|
||||
|
||||
func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
return b.lmdbEnv.Update(func(txn *lmdb.Txn) error {
|
||||
func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||
err = b.lmdbEnv.Update(func(txn *lmdb.Txn) error {
|
||||
// check if we already have this id
|
||||
_, existsErr := txn.Get(b.indexId, evt.ID[0:8])
|
||||
if existsErr == nil {
|
||||
@@ -26,24 +25,21 @@ func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
}
|
||||
|
||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||
var err error
|
||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
||||
err = b.query(txn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
||||
}
|
||||
|
||||
shouldStore := true
|
||||
for previous := range results {
|
||||
if qerr := b.query(txn, filter, 10 /* could be just 1 */, func(previous nostr.Event) bool {
|
||||
if nostr.IsOlder(previous, evt) {
|
||||
if err := b.delete(txn, previous.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
||||
if qerr := b.delete(txn, previous.ID); qerr != nil {
|
||||
qerr = fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, qerr)
|
||||
return false
|
||||
}
|
||||
deleted = append(deleted, previous)
|
||||
} else {
|
||||
// there is a newer event already stored, so we won't store this
|
||||
shouldStore = false
|
||||
}
|
||||
return true
|
||||
}); qerr != nil {
|
||||
return fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||
}
|
||||
if shouldStore {
|
||||
return b.save(txn, evt)
|
||||
@@ -51,4 +47,6 @@ func (b *LMDBBackend) ReplaceEvent(evt nostr.Event) error {
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
@@ -33,18 +33,15 @@ func (il *IndexingLayer) CountEvents(filter nostr.Filter) (uint32, error) {
|
||||
// we already have a k and a v and an err from the cursor setup, so check and use these
|
||||
if it.exhausted ||
|
||||
it.err != nil ||
|
||||
len(it.key) != q.keySize ||
|
||||
len(it.key) != len(q.prefix)+4 ||
|
||||
!bytes.HasPrefix(it.key, q.prefix) {
|
||||
// either iteration has errored or we reached the end of this prefix
|
||||
break // stop this cursor and move to the next one
|
||||
}
|
||||
|
||||
// "id" indexes don't contain a timestamp
|
||||
if q.timestampSize == 4 {
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
createdAt := binary.BigEndian.Uint32(it.key[len(it.key)-4:])
|
||||
if createdAt < since {
|
||||
break
|
||||
}
|
||||
|
||||
if extraAuthors == nil && extraKinds == nil && extraTagValues == nil {
|
||||
|
||||
@@ -116,8 +116,7 @@ func (b *MultiMmapManager) Rescan() error {
|
||||
}
|
||||
}
|
||||
|
||||
b.freeRanges, err = b.gatherFreeRanges(mmmtxn)
|
||||
if err != nil {
|
||||
if err := b.gatherFreeRanges(mmmtxn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||
)
|
||||
|
||||
func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) (positions, error) {
|
||||
const LARGE_FREERANGE = 142
|
||||
|
||||
func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) error {
|
||||
cursor, err := txn.OpenCursor(b.indexId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open cursor on indexId: %w", err)
|
||||
return fmt.Errorf("failed to open cursor on indexId: %w", err)
|
||||
}
|
||||
defer cursor.Close()
|
||||
|
||||
@@ -28,31 +30,35 @@ func (b *MultiMmapManager) gatherFreeRanges(txn *lmdb.Txn) (positions, error) {
|
||||
usedPositions = append(usedPositions, position{start: b.mmapfEnd, size: 0})
|
||||
|
||||
// calculate free ranges as gaps between used positions
|
||||
freeRanges := make(positions, 0, len(usedPositions)/2)
|
||||
b.freeRangesAll = make(positions, 0, len(usedPositions))
|
||||
b.freeRangesLarge = make([]position, 0, len(usedPositions)/10)
|
||||
var currentStart uint64 = 0
|
||||
for _, used := range usedPositions {
|
||||
if used.start > currentStart {
|
||||
// gap from currentStart to pos.start
|
||||
freeSize := used.start - currentStart
|
||||
if freeSize > 0 {
|
||||
freeRanges = append(freeRanges, position{
|
||||
fr := position{
|
||||
start: currentStart,
|
||||
size: uint32(freeSize),
|
||||
})
|
||||
}
|
||||
b.freeRangesAll = append(b.freeRangesAll, fr)
|
||||
if fr.isLarge() {
|
||||
b.freeRangesLarge = append(b.freeRangesLarge, fr)
|
||||
}
|
||||
}
|
||||
}
|
||||
currentStart = used.start + uint64(used.size)
|
||||
}
|
||||
|
||||
return freeRanges, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
||||
// use binary search to find the insertion point for the new pos
|
||||
idx, exists := slices.BinarySearchFunc(b.freeRanges, newFreeRange.start, func(item position, target uint64) int {
|
||||
idx, exists := slices.BinarySearchFunc(b.freeRangesAll, newFreeRange.start, func(item position, target uint64) int {
|
||||
return cmp.Compare(item.start, target)
|
||||
})
|
||||
|
||||
if exists {
|
||||
panic(fmt.Errorf("can't add free range that already exists: %s", newFreeRange))
|
||||
}
|
||||
@@ -62,7 +68,7 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
||||
|
||||
// check the range immediately before
|
||||
if idx > 0 {
|
||||
before := b.freeRanges[idx-1]
|
||||
before := b.freeRangesAll[idx-1]
|
||||
if before.start+uint64(before.size) == newFreeRange.start {
|
||||
deleteStart = idx - 1
|
||||
deleting++
|
||||
@@ -72,8 +78,8 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
||||
}
|
||||
|
||||
// check the range immediately after
|
||||
if idx < len(b.freeRanges) {
|
||||
after := b.freeRanges[idx]
|
||||
if idx < len(b.freeRangesAll) {
|
||||
after := b.freeRangesAll[idx]
|
||||
if newFreeRange.start+uint64(newFreeRange.size) == after.start {
|
||||
if deleteStart == -1 {
|
||||
deleteStart = idx
|
||||
@@ -87,13 +93,60 @@ func (b *MultiMmapManager) mergeNewFreeRange(newFreeRange position) {
|
||||
switch deleting {
|
||||
case 0:
|
||||
// if we are not deleting anything we must insert the new free range
|
||||
b.freeRanges = slices.Insert(b.freeRanges, idx, newFreeRange)
|
||||
b.freeRangesAll = slices.Insert(b.freeRangesAll, idx, newFreeRange)
|
||||
|
||||
// if it's large add it to the list of large free ranges
|
||||
if newFreeRange.isLarge() {
|
||||
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||
}
|
||||
case 1:
|
||||
deleted := b.freeRangesAll[deleteStart]
|
||||
|
||||
// if we're deleting a single range, don't delete it, modify it in-place instead.
|
||||
b.freeRanges[deleteStart] = newFreeRange
|
||||
b.freeRangesAll[deleteStart] = newFreeRange
|
||||
|
||||
// if the list we're modifying is in the list of large ranges modify it there too
|
||||
if deleted.isLarge() {
|
||||
for i, large := range b.freeRangesLarge {
|
||||
if large.start == deleted.start {
|
||||
b.freeRangesLarge[i] = newFreeRange
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if newFreeRange.isLarge() {
|
||||
// otherwise: if after modification it's big enough we should add it to list of large ranges
|
||||
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||
}
|
||||
case 2:
|
||||
// now if we're deleting two ranges, delete just one instead and modify the other in place
|
||||
b.freeRanges[deleteStart] = newFreeRange
|
||||
b.freeRanges = slices.Delete(b.freeRanges, deleteStart+1, deleteStart+1+1)
|
||||
// now if we're deleting two ranges, delete the second instead and modify the first in place
|
||||
first := b.freeRangesAll[deleteStart]
|
||||
second := b.freeRangesAll[deleteStart+1]
|
||||
|
||||
b.freeRangesAll = slices.Delete(b.freeRangesAll, deleteStart+1, deleteStart+1+1)
|
||||
b.freeRangesAll[deleteStart] = newFreeRange
|
||||
|
||||
// if the second was in the list of large lists delete it from there too
|
||||
if second.isLarge() {
|
||||
for i, large := range b.freeRangesLarge {
|
||||
if large.start == second.start {
|
||||
b.freeRangesLarge[i] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if the list we're modifying (the first) is already in the list of large ranges modify it there too
|
||||
if first.isLarge() {
|
||||
for i, large := range b.freeRangesLarge {
|
||||
if large.start == first.start {
|
||||
b.freeRangesLarge[i] = newFreeRange
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if newFreeRange.isLarge() {
|
||||
// otherwise if after modification has become big enough we should add it to list of large ranges
|
||||
b.freeRangesLarge = append(b.freeRangesLarge, newFreeRange)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func FuzzFreeRanges(f *testing.F) {
|
||||
|
||||
total := 0
|
||||
for {
|
||||
freeBefore, spaceBefore := countUsableFreeRanges(mmmm)
|
||||
freeBefore, spaceBefore := countUsableFreeRanges(t, mmmm)
|
||||
|
||||
hasAdded := false
|
||||
for i := range rnd.IntN(40) {
|
||||
@@ -69,7 +69,7 @@ func FuzzFreeRanges(f *testing.F) {
|
||||
total++
|
||||
}
|
||||
|
||||
freeAfter, spaceAfter := countUsableFreeRanges(mmmm)
|
||||
freeAfter, spaceAfter := countUsableFreeRanges(t, mmmm)
|
||||
if hasAdded && freeBefore > 0 {
|
||||
require.Lessf(t, spaceAfter, spaceBefore, "must use some of the existing free ranges when inserting new events (before: %d, after: %d)", freeBefore, freeAfter)
|
||||
}
|
||||
@@ -86,9 +86,35 @@ func FuzzFreeRanges(f *testing.F) {
|
||||
}
|
||||
}
|
||||
|
||||
verifyFreeRangesInvariants(t, mmmm)
|
||||
|
||||
// add more events
|
||||
for i := range rnd.IntN(40) {
|
||||
content := "1"
|
||||
if i > 0 {
|
||||
content = strings.Repeat("z", rnd.IntN(1000))
|
||||
}
|
||||
|
||||
evt := nostr.Event{
|
||||
CreatedAt: nostr.Timestamp(rnd.Uint32()),
|
||||
Kind: 1,
|
||||
Content: content,
|
||||
Tags: nostr.Tags{},
|
||||
}
|
||||
evt.Sign(sk)
|
||||
err := il.SaveEvent(evt)
|
||||
require.NoError(t, err)
|
||||
|
||||
total++
|
||||
}
|
||||
|
||||
verifyFreeRangesInvariants(t, mmmm)
|
||||
|
||||
mmmm.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
||||
expectedFreeRanges, _ := mmmm.gatherFreeRanges(txn)
|
||||
require.Equalf(t, expectedFreeRanges, mmmm.freeRanges, "expected %s, got %s", expectedFreeRanges, mmmm.freeRanges)
|
||||
before := mmmm.freeRangesAll
|
||||
err := mmmm.gatherFreeRanges(txn)
|
||||
require.NoError(t, err)
|
||||
require.Equalf(t, mmmm.freeRangesAll, before, "expected %s, got %s", before, mmmm.freeRangesAll)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -99,12 +125,54 @@ func FuzzFreeRanges(f *testing.F) {
|
||||
})
|
||||
}
|
||||
|
||||
func countUsableFreeRanges(mmmm *MultiMmapManager) (count int, space int) {
|
||||
for _, fr := range mmmm.freeRanges {
|
||||
if fr.size >= 142 {
|
||||
func countUsableFreeRanges(t *testing.T, mmmm *MultiMmapManager) (count int, space int) {
|
||||
for _, fr := range mmmm.freeRangesAll {
|
||||
if fr.size >= LARGE_FREERANGE {
|
||||
count++
|
||||
space += int(fr.size)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, count, len(mmmm.freeRangesLarge))
|
||||
|
||||
return count, space
|
||||
}
|
||||
|
||||
func verifyFreeRangesInvariants(t *testing.T, mmmm *MultiMmapManager) {
|
||||
all := mmmm.freeRangesAll
|
||||
large := mmmm.freeRangesLarge
|
||||
|
||||
for _, l := range large {
|
||||
found := false
|
||||
for _, a := range all {
|
||||
if l.start == a.start && l.size == a.size {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "large range %v not found in all ranges", l)
|
||||
}
|
||||
|
||||
for i := 1; i < len(all); i++ {
|
||||
require.Greater(t, all[i].start, all[i-1].start, "all ranges should be sorted by start")
|
||||
}
|
||||
|
||||
for i := range all {
|
||||
for j := i + 1; j < len(all); j++ {
|
||||
end1 := all[i].start + uint64(all[i].size)
|
||||
end2 := all[j].start + uint64(all[j].size)
|
||||
require.False(t, (all[i].start >= all[j].start && all[i].start < end2) ||
|
||||
(all[j].start >= all[i].start && all[j].start < end1),
|
||||
"ranges %v and %v overlap", all[i], all[j])
|
||||
}
|
||||
}
|
||||
|
||||
mmmm.lmdbEnv.View(func(txn *lmdb.Txn) error {
|
||||
before := make(positions, len(mmmm.freeRangesAll))
|
||||
copy(before, mmmm.freeRangesAll)
|
||||
err := mmmm.gatherFreeRanges(txn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, before, mmmm.freeRangesAll, "recomputing free ranges should yield the same result")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func (it *iterator) pull(n int, since uint32) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(it.key) != it.query.keySize || !bytes.HasPrefix(it.key, it.query.prefix) {
|
||||
if len(it.key) != len(it.query.prefix)+4 || !bytes.HasPrefix(it.key, it.query.prefix) {
|
||||
// we reached the end of this prefix
|
||||
it.exhausted = true
|
||||
return
|
||||
@@ -226,7 +226,7 @@ func (il *IndexingLayer) getIndexKeysForEvent(evt nostr.Event) iter.Seq[key] {
|
||||
return
|
||||
}
|
||||
|
||||
// now the p-tag+kind+date
|
||||
// now the p-1733934977tag+kind+date
|
||||
if dbi == il.indexTag32 && tag[0] == "p" {
|
||||
k := make([]byte, 8+2+4)
|
||||
xhex.Decode(k[0:8], []byte(tag[1][0:8*2]))
|
||||
|
||||
@@ -61,7 +61,7 @@ func (il *IndexingLayer) Init() error {
|
||||
|
||||
env.SetMaxDBs(9)
|
||||
env.SetMaxReaders(1000)
|
||||
env.SetMapSize(1 << 38) // ~273GB
|
||||
env.SetMapSize(MMAP_INFINITE_SIZE)
|
||||
|
||||
// create directory if it doesn't exist and open it
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
|
||||
+46
-24
@@ -35,13 +35,15 @@ type MultiMmapManager struct {
|
||||
mmapfEnd uint64
|
||||
|
||||
writeMutex sync.Mutex
|
||||
lockfile *os.File
|
||||
|
||||
lmdbEnv *lmdb.Env
|
||||
stuff lmdb.DBI
|
||||
knownLayers lmdb.DBI
|
||||
indexId lmdb.DBI
|
||||
|
||||
freeRanges positions
|
||||
freeRangesAll positions // sorted by position
|
||||
freeRangesLarge []position // unsorted
|
||||
}
|
||||
|
||||
func (b *MultiMmapManager) String() string {
|
||||
@@ -49,33 +51,43 @@ func (b *MultiMmapManager) String() string {
|
||||
}
|
||||
|
||||
const (
|
||||
MMAP_INFINITE_SIZE = 1 << 40
|
||||
MMAP_INFINITE_SIZE = 100_000_000_000
|
||||
maxuint16 = 65535
|
||||
maxuint32 = 4294967295
|
||||
)
|
||||
|
||||
func (b *MultiMmapManager) Init() error {
|
||||
func (b *MultiMmapManager) Init() (err error) {
|
||||
if b.Logger == nil {
|
||||
nopLogger := zerolog.Nop()
|
||||
b.Logger = &nopLogger
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
b.releaseLock()
|
||||
}
|
||||
}()
|
||||
|
||||
// create directory if it doesn't exist
|
||||
dbpath := filepath.Join(b.Dir, "mmmm")
|
||||
if err := os.MkdirAll(dbpath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dbpath, err)
|
||||
}
|
||||
|
||||
if !b.ReadOnly {
|
||||
// create lockfile to prevent multiple instances
|
||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||
if _, err := os.OpenFile(lockfilePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644); err != nil {
|
||||
if os.IsExist(err) {
|
||||
return fmt.Errorf("database at %s is already in use by another instance", b.Dir)
|
||||
}
|
||||
return fmt.Errorf("failed to create lockfile %s: %w", lockfilePath, err)
|
||||
}
|
||||
// lock database directory to prevent multiple instances
|
||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||
lockfile, err := os.OpenFile(lockfilePath, os.O_CREATE|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open lockfile %s: %w", lockfilePath, err)
|
||||
}
|
||||
if err := syscall.Flock(int(lockfile.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
|
||||
lockfile.Close()
|
||||
if errors.Is(err, syscall.EWOULDBLOCK) || errors.Is(err, syscall.EAGAIN) {
|
||||
return fmt.Errorf("database at %s is already in use by another instance", b.Dir)
|
||||
}
|
||||
return fmt.Errorf("failed to lock database at %s: %w", b.Dir, err)
|
||||
}
|
||||
b.lockfile = lockfile
|
||||
|
||||
// open a huge mmapped file
|
||||
b.mmapfPath = filepath.Join(b.Dir, "events")
|
||||
@@ -83,7 +95,7 @@ func (b *MultiMmapManager) Init() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open events file at %s: %w", b.mmapfPath, err)
|
||||
}
|
||||
mmapf, err := syscall.Mmap(int(file.Fd()), 0, MMAP_INFINITE_SIZE,
|
||||
mmapf, err := syscall.Mmap(int(file.Fd()), 0, int(MMAP_INFINITE_SIZE),
|
||||
syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mmap events file at %s: %w", b.mmapfPath, err)
|
||||
@@ -104,7 +116,7 @@ func (b *MultiMmapManager) Init() error {
|
||||
|
||||
env.SetMaxDBs(3)
|
||||
env.SetMaxReaders(1000)
|
||||
env.SetMapSize(1 << 38) // ~273GB
|
||||
env.SetMapSize(MMAP_INFINITE_SIZE)
|
||||
|
||||
err = env.Open(dbpath, lmdb.NoTLS, 0644)
|
||||
if err != nil {
|
||||
@@ -139,18 +151,17 @@ func (b *MultiMmapManager) Init() error {
|
||||
|
||||
if !b.ReadOnly {
|
||||
// scan index table to calculate free ranges from used positions
|
||||
b.freeRanges, err = b.gatherFreeRanges(txn)
|
||||
if err != nil {
|
||||
if err := b.gatherFreeRanges(txn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logOp := b.Logger.Debug()
|
||||
for _, pos := range b.freeRanges {
|
||||
if pos.size > 20 {
|
||||
logOp = logOp.Uint32(fmt.Sprintf("%d", pos.start), pos.size)
|
||||
}
|
||||
count := 0
|
||||
for _, pos := range b.freeRangesLarge {
|
||||
logOp = logOp.Uint32(fmt.Sprintf("%d", pos.start), pos.size)
|
||||
count++
|
||||
}
|
||||
logOp.Msg("calculated free ranges from index scan")
|
||||
logOp.Int("count", count).Msg("calculated free ranges from index scan")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -365,6 +376,19 @@ func (b *MultiMmapManager) getNextAvailableLayerId(txn *lmdb.Txn) (uint16, error
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (b *MultiMmapManager) releaseLock() {
|
||||
if b.lockfile == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = syscall.Flock(int(b.lockfile.Fd()), syscall.LOCK_UN)
|
||||
_ = b.lockfile.Close()
|
||||
b.lockfile = nil
|
||||
|
||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||
_ = os.Remove(lockfilePath)
|
||||
}
|
||||
|
||||
func (b *MultiMmapManager) Close() {
|
||||
b.lmdbEnv.Close()
|
||||
for _, il := range b.layers {
|
||||
@@ -373,7 +397,5 @@ func (b *MultiMmapManager) Close() {
|
||||
|
||||
syscall.Munmap(b.mmapf)
|
||||
|
||||
// remove lockfile
|
||||
lockfilePath := filepath.Join(b.Dir, "mmmm.lock")
|
||||
os.Remove(lockfilePath)
|
||||
b.releaseLock()
|
||||
}
|
||||
|
||||
@@ -1,22 +1,35 @@
|
||||
package mmm
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type positions []position
|
||||
|
||||
func (poss positions) find(start uint64) (idx int) {
|
||||
idx, _ = slices.BinarySearchFunc(poss, start, func(item position, target uint64) int {
|
||||
return cmp.Compare(item.start, target)
|
||||
})
|
||||
return idx
|
||||
}
|
||||
|
||||
func (poss positions) del(start uint64) positions {
|
||||
idx := poss.find(start)
|
||||
return slices.Delete(poss, idx, idx+1)
|
||||
}
|
||||
|
||||
func (poss positions) String() string {
|
||||
str := strings.Builder{}
|
||||
str.Grow(10 + 20*len(poss))
|
||||
str.WriteString("positions:[")
|
||||
for _, pos := range poss {
|
||||
str.WriteByte(' ')
|
||||
str.WriteString(pos.String())
|
||||
}
|
||||
str.WriteString(" ]")
|
||||
str.WriteString("]")
|
||||
return str.String()
|
||||
}
|
||||
|
||||
@@ -29,6 +42,10 @@ func (pos position) String() string {
|
||||
return fmt.Sprintf("<%d|%d|%d>", pos.start, pos.size, pos.start+uint64(pos.size))
|
||||
}
|
||||
|
||||
func (pos position) isLarge() bool {
|
||||
return pos.size >= LARGE_FREERANGE
|
||||
}
|
||||
|
||||
func positionFromBytes(posb []byte) position {
|
||||
return position{
|
||||
size: binary.BigEndian.Uint32(posb[0:4]),
|
||||
|
||||
+11
-3
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore/codec/betterbinary"
|
||||
@@ -14,6 +15,12 @@ import (
|
||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||
)
|
||||
|
||||
var tempResultsPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make([]nostr.Event, 0, 64)
|
||||
},
|
||||
}
|
||||
|
||||
// GetByID returns the event -- if found in this mmm -- and all the IndexingLayers it belongs to.
|
||||
func (b *MultiMmapManager) GetByID(id nostr.ID) (*nostr.Event, IndexingLayers) {
|
||||
var event *nostr.Event
|
||||
@@ -140,7 +147,8 @@ func (il *IndexingLayer) query(txn *lmdb.Txn, filter nostr.Filter, limit int, yi
|
||||
|
||||
numberOfIteratorsToPullOnEachRound := max(1, int(math.Ceil(float64(len(iterators))/float64(12))))
|
||||
totalEventsEmitted := 0
|
||||
tempResults := make([]nostr.Event, 0, batchSizePerQuery*2)
|
||||
tempResults := tempResultsPool.Get().([]nostr.Event)
|
||||
defer tempResultsPool.Put(tempResults[:0])
|
||||
|
||||
for len(iterators) > 0 {
|
||||
// reset stuff
|
||||
@@ -180,8 +188,8 @@ func (il *IndexingLayer) query(txn *lmdb.Txn, filter nostr.Filter, limit int, yi
|
||||
// decode the entire thing
|
||||
event := nostr.Event{}
|
||||
if err := betterbinary.Unmarshal(bin, &event); err != nil {
|
||||
log.Printf("lmdb: value read error (id %x) on query prefix %x sp %x dbi %v: %s\n",
|
||||
betterbinary.GetID(bin), iterators[i].query.prefix, iterators[i].query.startingPoint, iterators[i].query.dbi, err)
|
||||
log.Printf("mmm: value read error (id %s) on query prefix %x sp %x dbi %v: %s\n",
|
||||
betterbinary.GetID(bin).Hex(), iterators[i].query.prefix, iterators[i].query.startingPoint, iterators[i].query.dbi, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ type query struct {
|
||||
i int
|
||||
dbi lmdb.DBI
|
||||
prefix []byte
|
||||
keySize int
|
||||
timestampSize int
|
||||
startingPoint []byte
|
||||
}
|
||||
|
||||
@@ -41,10 +39,10 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
||||
}
|
||||
}
|
||||
for i, q := range queries {
|
||||
sp := make([]byte, len(q.prefix))
|
||||
sp = sp[0:len(q.prefix)]
|
||||
copy(sp, q.prefix)
|
||||
queries[i].startingPoint = binary.BigEndian.AppendUint32(sp, uint32(until))
|
||||
sp := make([]byte, len(q.prefix)+4)
|
||||
copy(sp[0:len(q.prefix)], q.prefix)
|
||||
binary.BigEndian.PutUint32(sp[len(q.prefix):], uint32(until))
|
||||
queries[i].startingPoint = sp
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -65,39 +63,27 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
||||
}
|
||||
|
||||
// only "p" tag has a goodness of 2, so
|
||||
if goodness == 2 {
|
||||
if goodness == 2 && filter.Kinds != nil {
|
||||
// this means we got a "p" tag, so we will use the ptag-kind index
|
||||
i := 0
|
||||
if filter.Kinds != nil {
|
||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||
for _, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
|
||||
for _, kind := range filter.Kinds {
|
||||
k := make([]byte, 8+2)
|
||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: il.indexPTagKind, prefix: k[0 : 8+2], keySize: 8 + 2 + 4, timestampSize: 4}
|
||||
i++
|
||||
}
|
||||
queries = make([]query, len(tagValues)*len(filter.Kinds))
|
||||
for _, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
} else {
|
||||
// even if there are no kinds, in that case we will just return any kind and not care
|
||||
queries = make([]query, len(tagValues))
|
||||
for i, value := range tagValues {
|
||||
if len(value) != 64 {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
|
||||
k := make([]byte, 8)
|
||||
for _, kind := range filter.Kinds {
|
||||
k := make([]byte, 8+2)
|
||||
if err := xhex.Decode(k[0:8], []byte(value[0:8*2])); err != nil {
|
||||
return nil, nil, nil, "", nil, 0, fmt.Errorf("invalid 'p' tag '%s'", value)
|
||||
}
|
||||
queries[i] = query{i: i, dbi: il.indexPTagKind, prefix: k[0:8], keySize: 8 + 2 + 4, timestampSize: 4}
|
||||
binary.BigEndian.PutUint16(k[8:8+2], uint16(kind))
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: il.indexPTagKind,
|
||||
prefix: k[0 : 8+2],
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -108,7 +94,11 @@ func (il *IndexingLayer) prepareQueries(filter nostr.Filter) (
|
||||
dbi, k, offset := il.getTagIndexPrefix(tagKey, value)
|
||||
// remove the last parts part to get just the prefix we want here
|
||||
prefix := k[0:offset]
|
||||
queries[i] = query{i: i, dbi: dbi, prefix: prefix, keySize: len(prefix) + 4, timestampSize: 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: dbi,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// add an extra kind filter if available (only do this on plain tag index, not on ptag-kind index)
|
||||
@@ -143,9 +133,11 @@ pubkeyMatching:
|
||||
// will use pubkey index
|
||||
queries = make([]query, len(filter.Authors))
|
||||
for i, pk := range filter.Authors {
|
||||
prefix := make([]byte, 8)
|
||||
copy(prefix[0:8], pk[0:8])
|
||||
queries[i] = query{i: i, dbi: il.indexPubkey, prefix: prefix[0:8], keySize: 8 + 4, timestampSize: 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: il.indexPubkey,
|
||||
prefix: pk[0:8],
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// will use pubkeyKind index
|
||||
@@ -156,7 +148,11 @@ pubkeyMatching:
|
||||
prefix := make([]byte, 8+2)
|
||||
copy(prefix[0:8], pk[0:8])
|
||||
binary.BigEndian.PutUint16(prefix[8:8+2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: il.indexPubkeyKind, prefix: prefix[0 : 8+2], keySize: 10 + 4, timestampSize: 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: il.indexPubkeyKind,
|
||||
prefix: prefix[0 : 8+2],
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
@@ -173,7 +169,11 @@ pubkeyMatching:
|
||||
for i, kind := range filter.Kinds {
|
||||
prefix := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(prefix[0:2], uint16(kind))
|
||||
queries[i] = query{i: i, dbi: il.indexKind, prefix: prefix[0:2], keySize: 2 + 4, timestampSize: 4}
|
||||
queries[i] = query{
|
||||
i: i,
|
||||
dbi: il.indexKind,
|
||||
prefix: prefix[0:2],
|
||||
}
|
||||
}
|
||||
|
||||
// potentially with an extra useless tag filtering
|
||||
@@ -184,6 +184,10 @@ pubkeyMatching:
|
||||
// if we got here our query will have nothing to filter with
|
||||
queries = make([]query, 1)
|
||||
prefix := make([]byte, 0)
|
||||
queries[0] = query{i: 0, dbi: il.indexCreatedAt, prefix: prefix, keySize: 0 + 4, timestampSize: 4}
|
||||
queries[0] = query{
|
||||
i: 0,
|
||||
dbi: il.indexCreatedAt,
|
||||
prefix: prefix,
|
||||
}
|
||||
return queries, nil, nil, "", nil, since, nil
|
||||
}
|
||||
|
||||
+18
-16
@@ -9,9 +9,9 @@ import (
|
||||
"github.com/PowerDNS/lmdb-go/lmdb"
|
||||
)
|
||||
|
||||
func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||
if il.mmmm.ReadOnly {
|
||||
return ReadOnly
|
||||
return nil, ReadOnly
|
||||
}
|
||||
|
||||
il.mmmm.writeMutex.Lock()
|
||||
@@ -29,7 +29,7 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
// prepare transactions
|
||||
mmmtxn, err := il.mmmm.lmdbEnv.BeginTxn(nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
||||
@@ -41,7 +41,7 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
|
||||
iltxn, err := il.lmdbEnv.BeginTxn(nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// defer abort but only if we haven't committed (we'll set it to nil after committing)
|
||||
@@ -54,33 +54,35 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
// check if we already have this id
|
||||
_, existsErr := mmmtxn.Get(il.mmmm.indexId, evt.ID[0:8])
|
||||
if existsErr == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
if !lmdb.IsNotFound(existsErr) {
|
||||
return fmt.Errorf("error checking existence: %w", existsErr)
|
||||
return nil, fmt.Errorf("error checking existence: %w", existsErr)
|
||||
}
|
||||
|
||||
// now we fetch the past events, whatever they are, delete them and then save the new
|
||||
var qerr error
|
||||
var results iter.Seq[nostr.Event] = func(yield func(nostr.Event) bool) {
|
||||
err = il.query(iltxn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||
qerr = il.query(iltxn, filter, 10 /* in theory limit could be just 1 and this should work */, yield)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query past events with %s: %w", filter, err)
|
||||
if qerr != nil {
|
||||
return nil, fmt.Errorf("failed to query past events with %s: %w", filter, qerr)
|
||||
}
|
||||
|
||||
var acquiredFreeRangeFromDelete *position
|
||||
shouldStore := true
|
||||
for previous := range results {
|
||||
if nostr.IsOlder(previous, evt) {
|
||||
if pos, shouldPurge, err := il.delete(mmmtxn, iltxn, previous.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, err)
|
||||
if pos, shouldPurge, derr := il.delete(mmmtxn, iltxn, previous.ID); derr != nil {
|
||||
return nil, fmt.Errorf("failed to delete event %s for replacing: %w", previous.ID, derr)
|
||||
} else if shouldPurge {
|
||||
// purge
|
||||
if err := mmmtxn.Del(il.mmmm.indexId, previous.ID[0:8], nil); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
acquiredFreeRangeFromDelete = &pos
|
||||
}
|
||||
deleted = append(deleted, previous)
|
||||
} else {
|
||||
// there is a newer event already stored, so we won't store this
|
||||
shouldStore = false
|
||||
@@ -90,17 +92,17 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
if shouldStore {
|
||||
_, err := il.mmmm.storeOn(mmmtxn, iltxn, il, evt)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// commit in this order to minimize problematic inconsistencies
|
||||
if err := mmmtxn.Commit(); err != nil {
|
||||
return fmt.Errorf("can't commit mmmtxn: %w", err)
|
||||
return nil, fmt.Errorf("can't commit mmmtxn: %w", err)
|
||||
}
|
||||
mmmtxn = nil
|
||||
if err := iltxn.Commit(); err != nil {
|
||||
return fmt.Errorf("can't commit iltxn: %w", err)
|
||||
return nil, fmt.Errorf("can't commit iltxn: %w", err)
|
||||
}
|
||||
iltxn = nil
|
||||
|
||||
@@ -110,5 +112,5 @@ func (il *IndexingLayer) ReplaceEvent(evt nostr.Event) error {
|
||||
il.mmmm.mergeNewFreeRange(*acquiredFreeRangeFromDelete)
|
||||
}
|
||||
|
||||
return nil
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
+22
-7
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
@@ -104,25 +103,41 @@ func (b *MultiMmapManager) storeOn(
|
||||
return false, fmt.Errorf("event too large to store, max %d, got %d", 1<<16, pos.size)
|
||||
}
|
||||
|
||||
// find a suitable place for this to be stored in
|
||||
// find a suitable place for this to be stored in (search only large free ranges)
|
||||
appendToMmap := true
|
||||
for f, fr := range b.freeRanges {
|
||||
for f, fr := range b.freeRangesLarge {
|
||||
if fr.size >= pos.size {
|
||||
// found the smallest possible place that can fit this event
|
||||
// found a place that can fit this event
|
||||
appendToMmap = false
|
||||
pos.start = fr.start
|
||||
|
||||
// modify the free ranges we're keeping track of
|
||||
// (in case of conflict we lose this free range but it's ok, it will be recovered on the next startup)
|
||||
if pos.size == fr.size {
|
||||
// if we've used it entirely just delete it
|
||||
b.freeRanges = slices.Delete(b.freeRanges, f, f+1)
|
||||
// if we've used it entirely just delete it (swap-delete since it's unsorted)
|
||||
b.freeRangesLarge[f] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||
|
||||
// also delete it from b.freeRangesAll
|
||||
b.freeRangesAll = b.freeRangesAll.del(fr.start)
|
||||
} else {
|
||||
// otherwise modify it in place
|
||||
b.freeRanges[f] = position{
|
||||
newFreeRange := position{
|
||||
start: fr.start + uint64(pos.size),
|
||||
size: fr.size - pos.size,
|
||||
}
|
||||
// only keep it in freeRangesLarge if it's still large enough
|
||||
if newFreeRange.size >= LARGE_FREERANGE {
|
||||
b.freeRangesLarge[f] = newFreeRange
|
||||
} else {
|
||||
// remove it from freeRangesLarge if it's no longer large enough
|
||||
b.freeRangesLarge[f] = b.freeRangesLarge[len(b.freeRangesLarge)-1]
|
||||
b.freeRangesLarge = b.freeRangesLarge[0 : len(b.freeRangesLarge)-1]
|
||||
}
|
||||
|
||||
// also modify it in b.freeRangesAll
|
||||
idx := b.freeRangesAll.find(fr.start)
|
||||
b.freeRangesAll[idx] = newFreeRange
|
||||
}
|
||||
|
||||
break
|
||||
|
||||
@@ -29,8 +29,8 @@ func (b NullStore) SaveEvent(evt nostr.Event) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b NullStore) ReplaceEvent(evt nostr.Event) error {
|
||||
return nil
|
||||
func (b NullStore) ReplaceEvent(evt nostr.Event) ([]nostr.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b NullStore) CountEvents(filter nostr.Filter) (uint32, error) {
|
||||
|
||||
@@ -122,7 +122,7 @@ func (b *SliceStore) delete(id nostr.ID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
||||
func (b *SliceStore) ReplaceEvent(evt nostr.Event) (deleted []nostr.Event, err error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
@@ -135,8 +135,9 @@ func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
||||
for previous := range b.QueryEvents(filter, 1) {
|
||||
if nostr.IsOlder(previous, evt) {
|
||||
if err := b.delete(previous.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete event for replacing: %w", err)
|
||||
return nil, fmt.Errorf("failed to delete event for replacing: %w", err)
|
||||
}
|
||||
deleted = append(deleted, previous)
|
||||
} else {
|
||||
shouldStore = false
|
||||
}
|
||||
@@ -144,11 +145,11 @@ func (b *SliceStore) ReplaceEvent(evt nostr.Event) error {
|
||||
|
||||
if shouldStore {
|
||||
if err := b.save(evt); err != nil && err != eventstore.ErrDupEvent {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
return nil, fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func eventTimestampComparator(e nostr.Event, t nostr.Timestamp) int {
|
||||
|
||||
+1
-1
@@ -26,7 +26,7 @@ type Store interface {
|
||||
|
||||
// ReplaceEvent atomically replaces a replaceable or addressable event.
|
||||
// Conceptually it is like a Query->Delete->Save, but streamlined.
|
||||
ReplaceEvent(nostr.Event) error
|
||||
ReplaceEvent(nostr.Event) (deleted []nostr.Event, err error)
|
||||
|
||||
// CountEvents counts all events that match a given filter
|
||||
CountEvents(nostr.Filter) (uint32, error)
|
||||
|
||||
@@ -128,6 +128,24 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, events[5].ID, results[0].ID, "author + kind query error")
|
||||
}
|
||||
|
||||
// test 5: until
|
||||
{
|
||||
results := slices.Collect(db.QueryEvents(nostr.Filter{Until: 102}, 1000))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
resultsWithTag := slices.Collect(db.QueryEvents(nostr.Filter{
|
||||
Until: 102,
|
||||
Tags: nostr.TagMap{
|
||||
"e": []string{
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
},
|
||||
},
|
||||
}, 1000))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resultsWithTag, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// from another-basic-test.patch
|
||||
@@ -223,7 +241,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
}
|
||||
originalProfile.Sign(sk3)
|
||||
|
||||
err = db.ReplaceEvent(originalProfile)
|
||||
_, err = db.ReplaceEvent(originalProfile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify
|
||||
@@ -244,7 +262,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
newProfile.Sign(sk3)
|
||||
|
||||
// replace with newer event
|
||||
err = db.ReplaceEvent(newProfile)
|
||||
_, err = db.ReplaceEvent(newProfile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify only the newer event exists
|
||||
@@ -264,7 +282,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
}
|
||||
olderProfile.Sign(sk3)
|
||||
|
||||
err = db.ReplaceEvent(olderProfile)
|
||||
_, err = db.ReplaceEvent(olderProfile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify the newer event is still there
|
||||
@@ -284,7 +302,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
}
|
||||
articleV1.Sign(sk3)
|
||||
|
||||
err = db.ReplaceEvent(articleV1)
|
||||
_, err = db.ReplaceEvent(articleV1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify article was saved
|
||||
@@ -305,7 +323,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
}
|
||||
articleV2.Sign(sk3)
|
||||
|
||||
err = db.ReplaceEvent(articleV2)
|
||||
_, err = db.ReplaceEvent(articleV2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify only the newer version exists
|
||||
@@ -327,7 +345,7 @@ func basicTest(t *testing.T, db eventstore.Store) {
|
||||
}
|
||||
differentArticle.Sign(sk3)
|
||||
|
||||
err = db.ReplaceEvent(differentArticle)
|
||||
_, err = db.ReplaceEvent(differentArticle)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify both articles exist (different d tags)
|
||||
|
||||
@@ -33,6 +33,7 @@ var tests = []struct {
|
||||
{"manyauthors", manyAuthorsTest},
|
||||
{"unbalanced", unbalancedTest},
|
||||
{"count", countTest},
|
||||
{"pfilter-until", pTagUntilMismatchTest},
|
||||
}
|
||||
|
||||
func TestSliceStore(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func pTagUntilMismatchTest(t *testing.T, db eventstore.Store) {
|
||||
err := db.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
targetP := "460c25e682fda7832b52d1f22d3d22b3176d972f60dcdc3212ed8c92ef85065c"
|
||||
author := nostr.MustPubKeyFromHex("7fa56f5d6962ab1e3cd424e758c3002b8665f7b0d8dcee9fe9e288d7751ac194")
|
||||
|
||||
events := []nostr.Event{
|
||||
{
|
||||
Kind: 9802,
|
||||
ID: nostr.MustIDFromHex("2c997233fa580b1a831f989d8fa320c409f8412d9c75b819c9a29df102d7f901"),
|
||||
PubKey: author,
|
||||
CreatedAt: 1773835689,
|
||||
Tags: nostr.Tags{
|
||||
{"e", "bcaa6599e69cff48ed6ab4b0b315d4f33a869a4ba8fa808287700faebe17195f", "wss://nos.lol/", "source"},
|
||||
{"p", targetP, "", "author"},
|
||||
},
|
||||
Content: "With so few people donating in the zap the devs button, the incentives are quite low to produce cool new things",
|
||||
Sig: sigFromHex(t, "b49206476c4d2a5f44590331541c83910fd826c0f4cdab99ceffd5bcf3aca94935e3db9d7820e7db3a0f1165c43a28dd3173a81fd08bf8348629ea4efde02537"),
|
||||
},
|
||||
{
|
||||
Kind: 9802,
|
||||
ID: nostr.MustIDFromHex("31c1eddb3a5201ef1bbce91b9fb3b7d8fe3e3eb25a66bedadcbc93c84d072c7d"),
|
||||
PubKey: author,
|
||||
CreatedAt: 1773154080,
|
||||
Tags: nostr.Tags{
|
||||
{"p", "5ea4648045bb1ff222655ddd36e6dceddc43590c26090c486bef38ef450da5bd", "", "mention"},
|
||||
{"p", "c8fb0d3aa788b9ace4f6cb92dd97d3f292db25b5c9f92462ef6c64926129fbaf", "", "mention"},
|
||||
{"p", "2f29aa33c2a3b45c2ef32212879248b2f4a49a002bd0de0fa16c94e138ac6f13", "", "mention"},
|
||||
{"p", targetP, "", "mention"},
|
||||
{"comment", "normie"},
|
||||
{"e", "5911eeba39a6886fe8abea82bb50612d27d1273d63904c9b64cde070c7088d48", "wss://relay.primal.net/", "source"},
|
||||
{"p", "3f770d65d3a764a9c5cb503ae123e62ec7598ad035d836e2a810f3877a745b24", "", "author"},
|
||||
},
|
||||
Content: "grimoire is cool, but it's too nerdy for me",
|
||||
Sig: sigFromHex(t, "0ee01515d54293d52fa1247a395e64c8499df96eee80c30204cb7c8fc5b5977023e6ac00cc240f137ce7e594818545340bf74bce6d3de86539f1a5d26fe33f24"),
|
||||
},
|
||||
{
|
||||
Kind: 9802,
|
||||
ID: nostr.MustIDFromHex("baeb90e2075c9d8a9b41286dbf1c52e5ef8ad6c030118839ce24d065e72df9b7"),
|
||||
PubKey: author,
|
||||
CreatedAt: 1773154058,
|
||||
Tags: nostr.Tags{
|
||||
{"p", "c8fb0d3aa788b9ace4f6cb92dd97d3f292db25b5c9f92462ef6c64926129fbaf", "", "mention"},
|
||||
{"p", "2f29aa33c2a3b45c2ef32212879248b2f4a49a002bd0de0fa16c94e138ac6f13", "", "mention"},
|
||||
{"p", targetP, "", "mention"},
|
||||
{"p", "3f770d65d3a764a9c5cb503ae123e62ec7598ad035d836e2a810f3877a745b24", "", "mention"},
|
||||
{"comment", "no lies detected"},
|
||||
{"e", "81171c564cedbc5f07e5b7a9d06842d1f43a81cd79c8755921190382de55c514", "wss://nos.lol/", "source"},
|
||||
{"p", "5ea4648045bb1ff222655ddd36e6dceddc43590c26090c486bef38ef450da5bd", "", "author"},
|
||||
},
|
||||
Content: "i never used grimoire once in my life, but that is not the point, it is still the best client",
|
||||
Sig: sigFromHex(t, "3e4b855c4e3a4d2b3a078d593e728f1bbdb07af91ba5831b7866e2d16df90ce58a5f9f1db5733603911b20b334bc7fc5ae1482c7870b9f7acde4a8ccc080a79d"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, evt := range events {
|
||||
err = db.SaveEvent(evt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
results := slices.Collect(db.QueryEvents(nostr.Filter{
|
||||
Until: 1733934976,
|
||||
Limit: 3,
|
||||
Tags: nostr.TagMap{"p": []string{targetP}},
|
||||
}, 1000))
|
||||
require.Len(t, results, 0)
|
||||
}
|
||||
|
||||
func sigFromHex(t *testing.T, sigStr string) [64]byte {
|
||||
t.Helper()
|
||||
|
||||
raw, err := hex.DecodeString(sigStr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, raw, 64)
|
||||
|
||||
var sig [64]byte
|
||||
copy(sig[:], raw)
|
||||
return sig
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package wrappers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
)
|
||||
|
||||
var _ nostr.Publisher = DynamicPublisher{}
|
||||
|
||||
type DynamicPublisher struct {
|
||||
GetStore func() eventstore.Store
|
||||
MaxLimit int
|
||||
}
|
||||
|
||||
func (w DynamicPublisher) QueryEvents(filter nostr.Filter) iter.Seq[nostr.Event] {
|
||||
return w.GetStore().QueryEvents(filter, w.MaxLimit)
|
||||
}
|
||||
|
||||
func (w DynamicPublisher) Publish(ctx context.Context, evt nostr.Event) error {
|
||||
if evt.Kind.IsEphemeral() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
if evt.Kind.IsRegular() {
|
||||
if err := w.GetStore().SaveEvent(evt); err != nil && err != eventstore.ErrDupEvent {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err := w.GetStore().ReplaceEvent(evt)
|
||||
return err
|
||||
}
|
||||
@@ -39,5 +39,6 @@ func (w StorePublisher) Publish(ctx context.Context, evt nostr.Event) error {
|
||||
}
|
||||
|
||||
// others are replaced
|
||||
return w.Store.ReplaceEvent(evt)
|
||||
_, err := w.Store.ReplaceEvent(evt)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ func (ef Filter) Matches(event Event) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func (ef Filter) MatchesIgnoringTimestampConstraints(event Event) bool {
|
||||
if ef.IDs != nil && !slices.Contains(ef.IDs, event.ID) {
|
||||
return false
|
||||
|
||||
@@ -40,8 +40,10 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
fiatjaf.com/lib v0.3.6
|
||||
github.com/dgraph-io/ristretto/v2 v2.3.0
|
||||
github.com/go-git/go-git/v5 v5.16.3
|
||||
github.com/pemistahl/lingua-go v1.4.0
|
||||
github.com/sivukhin/godjot v1.0.6
|
||||
github.com/templexxx/cpu v0.0.1
|
||||
github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b
|
||||
@@ -63,6 +65,7 @@ require (
|
||||
github.com/blevesearch/scorch_segment_api/v2 v2.2.16 // indirect
|
||||
github.com/blevesearch/segment v0.9.1 // indirect
|
||||
github.com/blevesearch/snowballstem v0.9.0 // indirect
|
||||
github.com/blevesearch/stempel v0.2.0 // indirect
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect
|
||||
github.com/blevesearch/vellum v1.0.11 // indirect
|
||||
github.com/blevesearch/zapx/v11 v11.3.10 // indirect
|
||||
@@ -93,6 +96,7 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect
|
||||
github.com/shopspring/decimal v1.3.1 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
fiatjaf.com/lib v0.3.6 h1:GRZNSxHI2EWdjSKVuzaT+c0aifLDtS16SzkeJaHyJfY=
|
||||
fiatjaf.com/lib v0.3.6/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8=
|
||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
|
||||
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
|
||||
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
|
||||
@@ -40,6 +42,8 @@ github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+j
|
||||
github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
|
||||
github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s=
|
||||
github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs=
|
||||
github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB+Afc=
|
||||
github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc=
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A=
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
|
||||
github.com/blevesearch/vellum v1.0.11 h1:SJI97toEFTtA9WsDZxkyGTaBWFdWl1n2LEDCXLCq/AU=
|
||||
@@ -190,6 +194,8 @@ github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/pemistahl/lingua-go v1.4.0 h1:ifYhthrlW7iO4icdubwlduYnmwU37V1sbNrwhKBR4rM=
|
||||
github.com/pemistahl/lingua-go v1.4.0/go.mod h1:ECuM1Hp/3hvyh7k8aWSqNCPlTxLemFZsRjocUf3KgME=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -207,6 +213,8 @@ github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
|
||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/sivukhin/godjot v1.0.6 h1:yoRD+hlcDbSxP9Gd/KRVlEFXgtGyZyt0CHwhY6Gk3EQ=
|
||||
github.com/sivukhin/godjot v1.0.6/go.mod h1:wA6KdR4Z+XpwdwyViPDLWYYxT72pKjNc6XGA9I025gM=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
||||
+1
-1
@@ -46,7 +46,7 @@ func (rl *Relay) handleNormal(ctx context.Context, evt nostr.Event) (skipBroadca
|
||||
} else {
|
||||
// otherwise it's a replaceable
|
||||
if nil != rl.ReplaceEvent {
|
||||
if err := rl.ReplaceEvent(ctx, evt); err != nil {
|
||||
if _, err := rl.ReplaceEvent(ctx, evt); err != nil {
|
||||
switch err {
|
||||
case eventstore.ErrDupEvent:
|
||||
return true, nil
|
||||
|
||||
@@ -78,6 +78,9 @@ func (rl *Relay) handleDeleteRequest(ctx context.Context, evt nostr.Event) error
|
||||
}
|
||||
|
||||
haveDeletedSomething = true
|
||||
if rl.OnEventDeleted != nil {
|
||||
rl.OnEventDeleted(ctx, target)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
---
|
||||
outline: deep
|
||||
---
|
||||
|
||||
# Request Routing
|
||||
|
||||
If you have one (or more) set of policies that have to be executed in sequence (for example, first you check for the presence of a tag, then later in the next policies you use that tag without checking) and they only apply to some class of events, but you still want your relay to deal with other classes of events that can lead to cumbersome sets of rules, always having to check if an event meets the requirements and so on. There is where routing can help you.
|
||||
|
||||
```go
|
||||
sk := os.Getenv("RELAY_SECRET_KEY")
|
||||
|
||||
// a relay for NIP-29 groups
|
||||
groupsStore := boltdb.BoltBackend{}
|
||||
groupsStore.Init()
|
||||
groupsRelay, _ := khatru29.Init(relay29.Options{Domain: "example.com", DB: groupsStore, SecretKey: sk})
|
||||
// ...
|
||||
|
||||
// a relay for everything else
|
||||
publicStore := slicestore.SliceStore{}
|
||||
publicStore.Init()
|
||||
publicRelay := khatru.NewRelay()
|
||||
publicRelay.UseEventStore(publicStore, 1000)
|
||||
// ...
|
||||
|
||||
// a higher-level relay that just routes between the two above
|
||||
router := khatru.NewRouter()
|
||||
|
||||
// route requests and events to the groups relay
|
||||
router.Route().
|
||||
Req(func (filter nostr.Filter) bool {
|
||||
_, hasHTag := filter.Tags["h"]
|
||||
if hasHTag {
|
||||
return true
|
||||
}
|
||||
return slices.Contains(filter.Kinds, func (k int) bool { return k == 39000 || k == 39001 || k == 39002 })
|
||||
}).
|
||||
Event(func (event *nostr.Event) bool {
|
||||
switch {
|
||||
case event.Kind <= 9021 && event.Kind >= 9000:
|
||||
return true
|
||||
case event.Kind <= 39010 && event.Kind >= 39000:
|
||||
return true
|
||||
case event.Kind <= 12 && event.Kind >= 9:
|
||||
return true
|
||||
case event.Tags.Find("h") != nil:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}).
|
||||
Relay(groupsRelay)
|
||||
|
||||
// route requests and events to the other
|
||||
router.Route().
|
||||
Req(func (filter nostr.Filter) bool { return true }).
|
||||
Event(func (event *nostr.Event) bool { return true }).
|
||||
Relay(publicRelay)
|
||||
```
|
||||
@@ -1,61 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore/lmdb"
|
||||
"fiatjaf.com/nostr/eventstore/slicestore"
|
||||
"fiatjaf.com/nostr/khatru"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db1 := &slicestore.SliceStore{}
|
||||
db1.Init()
|
||||
r1 := khatru.NewRelay()
|
||||
r1.UseEventstore(db1, 400)
|
||||
|
||||
db2 := &lmdb.LMDBBackend{Path: "/tmp/t"}
|
||||
db2.Init()
|
||||
r2 := khatru.NewRelay()
|
||||
r2.UseEventstore(db2, 400)
|
||||
|
||||
db3 := &slicestore.SliceStore{}
|
||||
db3.Init()
|
||||
r3 := khatru.NewRelay()
|
||||
r3.UseEventstore(db3, 400)
|
||||
|
||||
router := khatru.NewRouter()
|
||||
|
||||
router.Route().
|
||||
Req(func(filter nostr.Filter) bool {
|
||||
return slices.Contains(filter.Kinds, 30023)
|
||||
}).
|
||||
Event(func(event *nostr.Event) bool {
|
||||
return event.Kind == 30023
|
||||
}).
|
||||
Relay(r1)
|
||||
|
||||
router.Route().
|
||||
Req(func(filter nostr.Filter) bool {
|
||||
return slices.Contains(filter.Kinds, 1) && slices.Contains(filter.Tags["t"], "spam")
|
||||
}).
|
||||
Event(func(event *nostr.Event) bool {
|
||||
return event.Kind == 1 && event.Tags.FindWithValue("t", "spam") != nil
|
||||
}).
|
||||
Relay(r2)
|
||||
|
||||
router.Route().
|
||||
Req(func(filter nostr.Filter) bool {
|
||||
return slices.Contains(filter.Kinds, 1)
|
||||
}).
|
||||
Event(func(event *nostr.Event) bool {
|
||||
return event.Kind == 1
|
||||
}).
|
||||
Relay(r3)
|
||||
|
||||
fmt.Println("running on :3334")
|
||||
http.ListenAndServe(":3334", router)
|
||||
}
|
||||
+15
-3
@@ -39,9 +39,15 @@ type expirationManager struct {
|
||||
events expiringEventHeap
|
||||
mu sync.Mutex
|
||||
|
||||
// a function to query the relay database, generally the same as relay.queryStored
|
||||
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event]
|
||||
|
||||
// a function to delete an event from the relay database, generally the same as relay.DeleteEvent
|
||||
deleteEvent func(ctx context.Context, id nostr.ID) error
|
||||
|
||||
// a function to call after an event has been deleted, generally the same as relay.OnEventDeleted
|
||||
deleteCallback func(ctx context.Context, id nostr.Event)
|
||||
|
||||
interval time.Duration
|
||||
initialScanDone bool
|
||||
kill chan struct{} // used for manually killing this
|
||||
@@ -109,7 +115,11 @@ func (em *expirationManager) checkExpiredEvents(ctx context.Context) {
|
||||
heap.Pop(&em.events)
|
||||
|
||||
ctx := context.WithValue(ctx, internalCallKey, struct{}{})
|
||||
em.deleteEvent(ctx, next.id)
|
||||
if nil == em.deleteEvent(ctx, next.id) && em.deleteCallback != nil {
|
||||
for evt := range em.queryStored(ctx, nostr.Filter{IDs: []nostr.ID{next.id}}) {
|
||||
em.deleteCallback(ctx, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,12 +152,14 @@ func (em *expirationManager) removeEvent(id nostr.ID) {
|
||||
func (rl *Relay) StartExpirationManager(
|
||||
queryStored func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event],
|
||||
deleteEvent func(ctx context.Context, id nostr.ID) error,
|
||||
onDeleteCallback func(ctx context.Context, evt nostr.Event),
|
||||
) {
|
||||
rl.expirationManager = &expirationManager{
|
||||
events: make(expiringEventHeap, 0),
|
||||
|
||||
queryStored: queryStored,
|
||||
deleteEvent: deleteEvent,
|
||||
queryStored: queryStored,
|
||||
deleteEvent: deleteEvent,
|
||||
deleteCallback: onDeleteCallback,
|
||||
|
||||
interval: time.Hour,
|
||||
kill: make(chan struct{}),
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package khatru
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/websocket"
|
||||
"github.com/rs/cors"
|
||||
)
|
||||
|
||||
func (rl *Relay) Router() *http.ServeMux {
|
||||
return rl.serveMux
|
||||
}
|
||||
|
||||
func (rl *Relay) SetRouter(mux *http.ServeMux) {
|
||||
rl.serveMux = mux
|
||||
}
|
||||
|
||||
// Start creates an http server and starts listening on given host and port.
|
||||
func (rl *Relay) Start(host string, port int, started ...chan bool) error {
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rl.Addr = ln.Addr().String()
|
||||
rl.httpServer = &http.Server{
|
||||
Handler: cors.Default().Handler(rl),
|
||||
Addr: addr,
|
||||
WriteTimeout: 2 * time.Second,
|
||||
ReadTimeout: 2 * time.Second,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// notify caller that we're starting
|
||||
for _, started := range started {
|
||||
close(started)
|
||||
}
|
||||
|
||||
if err := rl.httpServer.Serve(ln); err == http.ErrServerClosed {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown sends a websocket close control message to all connected clients.
|
||||
func (rl *Relay) Shutdown(ctx context.Context) {
|
||||
rl.httpServer.Shutdown(ctx)
|
||||
rl.clientsMutex.Lock()
|
||||
defer rl.clientsMutex.Unlock()
|
||||
for ws := range rl.clients {
|
||||
ws.conn.WriteControl(websocket.CloseMessage, nil, time.Now().Add(time.Second))
|
||||
ws.cancel()
|
||||
ws.conn.Close()
|
||||
}
|
||||
clear(rl.clients)
|
||||
rl.listeners = rl.listeners[:0]
|
||||
}
|
||||
+28
-40
@@ -108,17 +108,20 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
),
|
||||
)
|
||||
|
||||
killOnce := sync.Once{}
|
||||
kill := func() {
|
||||
if nil != rl.OnDisconnect {
|
||||
rl.OnDisconnect(ctx)
|
||||
}
|
||||
killOnce.Do(func() {
|
||||
if nil != rl.OnDisconnect {
|
||||
rl.OnDisconnect(ctx)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
cancel()
|
||||
ws.cancel()
|
||||
ws.conn.Close()
|
||||
ticker.Stop()
|
||||
cancel()
|
||||
ws.cancel()
|
||||
ws.conn.Close()
|
||||
|
||||
rl.removeClientAndListeners(ws)
|
||||
rl.removeClientAndListeners(ws)
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -214,35 +217,30 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
srl := rl
|
||||
if rl.getSubRelayFromEvent != nil {
|
||||
srl = rl.getSubRelayFromEvent(&env.Event)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
var writeErr error
|
||||
var skipBroadcast bool
|
||||
|
||||
if env.Event.Kind == nostr.KindDeletion {
|
||||
// store the delete event first
|
||||
skipBroadcast, writeErr = srl.handleNormal(ctx, env.Event)
|
||||
skipBroadcast, writeErr = rl.handleNormal(ctx, env.Event)
|
||||
if writeErr == nil {
|
||||
// this always returns "blocked: " whenever it returns an error
|
||||
writeErr = srl.handleDeleteRequest(ctx, env.Event)
|
||||
writeErr = rl.handleDeleteRequest(ctx, env.Event)
|
||||
}
|
||||
} else if env.Event.Kind.IsEphemeral() {
|
||||
// this will also always return a prefixed reason
|
||||
writeErr = srl.handleEphemeral(ctx, env.Event)
|
||||
writeErr = rl.handleEphemeral(ctx, env.Event)
|
||||
} else {
|
||||
// this will also always return a prefixed reason
|
||||
skipBroadcast, writeErr = srl.handleNormal(ctx, env.Event)
|
||||
skipBroadcast, writeErr = rl.handleNormal(ctx, env.Event)
|
||||
}
|
||||
|
||||
var reason string
|
||||
if writeErr == nil {
|
||||
ok = true
|
||||
if !skipBroadcast {
|
||||
n := srl.notifyListeners(env.Event, false)
|
||||
n := rl.notifyListeners(env.Event, false)
|
||||
|
||||
// the number of notified listeners matters in ephemeral events
|
||||
if env.Event.Kind.IsEphemeral() {
|
||||
@@ -275,15 +273,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
var total uint32
|
||||
var hll *hyperloglog.HyperLogLog
|
||||
|
||||
srl := rl
|
||||
if rl.getSubRelayFromFilter != nil {
|
||||
srl = rl.getSubRelayFromFilter(env.Filter)
|
||||
}
|
||||
|
||||
if offset := nip45.HyperLogLogEventPubkeyOffsetForFilter(env.Filter); offset != -1 {
|
||||
total, hll = srl.handleCountRequestWithHLL(ctx, ws, env.Filter, offset)
|
||||
total, hll = rl.handleCountRequestWithHLL(ctx, ws, env.Filter, offset)
|
||||
} else {
|
||||
total = srl.handleCountRequest(ctx, ws, env.Filter)
|
||||
total = rl.handleCountRequest(ctx, ws, env.Filter)
|
||||
}
|
||||
|
||||
resp := nostr.CountEnvelope{
|
||||
@@ -308,11 +301,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handle each filter separately -- dispatching events as they're loaded from databases
|
||||
for _, filter := range env.Filters {
|
||||
srl := rl
|
||||
if rl.getSubRelayFromFilter != nil {
|
||||
srl = rl.getSubRelayFromFilter(filter)
|
||||
}
|
||||
err := srl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
||||
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
||||
if err != nil {
|
||||
// fail everything if any filter is rejected
|
||||
reason := err.Error()
|
||||
@@ -322,8 +311,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
|
||||
cancelReqCtx(errors.New("filter rejected"))
|
||||
return
|
||||
} else {
|
||||
rl.addListener(ws, env.SubscriptionID, srl, filter, cancelReqCtx)
|
||||
} else if filter.IDs == nil {
|
||||
// a query that is just a bunch of "ids": [...] will not add listeners.
|
||||
// is this a bug? maybe, but I don't think anyone is listening for an ID
|
||||
// that hasn't been published yet anywhere -- if yes we can change later
|
||||
rl.addListener(ws, env.SubscriptionID, filter, cancelReqCtx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,15 +352,11 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate: " + err.Error()})
|
||||
}
|
||||
case *nip77.OpenEnvelope:
|
||||
srl := rl
|
||||
if rl.getSubRelayFromFilter != nil {
|
||||
srl = rl.getSubRelayFromFilter(env.Filter)
|
||||
if !srl.Negentropy {
|
||||
// ignore
|
||||
return
|
||||
}
|
||||
if !rl.Negentropy {
|
||||
// ignore
|
||||
return
|
||||
}
|
||||
vec, err := srl.startNegentropySession(ctx, env.Filter)
|
||||
vec, err := rl.startNegentropySession(ctx, env.Filter)
|
||||
if err != nil {
|
||||
// fail everything if any filter is rejected
|
||||
reason := err.Error()
|
||||
|
||||
+225
-77
@@ -3,18 +3,19 @@ package khatru
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"iter"
|
||||
|
||||
"fiatjaf.com/lib/set"
|
||||
"fiatjaf.com/nostr"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
)
|
||||
|
||||
var ErrSubscriptionClosedByClient = errors.New("subscription closed by client")
|
||||
|
||||
type listenerSpec struct {
|
||||
id string // kept here so we can easily match against it removeListenerId
|
||||
cancel context.CancelCauseFunc
|
||||
index int
|
||||
subrelay *Relay // this is important when we're dealing with routing, otherwise it will be always the same
|
||||
ssid int // internal numeric id for a listener
|
||||
sid string // client-provided subscription id
|
||||
cancel context.CancelCauseFunc
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
@@ -23,10 +24,199 @@ type listener struct {
|
||||
ws *WebSocket
|
||||
}
|
||||
|
||||
type subscription struct {
|
||||
id string
|
||||
filter nostr.Filter
|
||||
ws *WebSocket
|
||||
}
|
||||
|
||||
type dispatcher struct {
|
||||
serial int
|
||||
subscriptions *xsync.MapOf[int, subscription]
|
||||
byAuthor map[nostr.PubKey]set.Set[int]
|
||||
byKind map[nostr.Kind]set.Set[int]
|
||||
fallback set.Set[int]
|
||||
}
|
||||
|
||||
func newDispatcher() dispatcher {
|
||||
return dispatcher{
|
||||
subscriptions: xsync.NewMapOf[int, subscription](),
|
||||
byAuthor: make(map[nostr.PubKey]set.Set[int]),
|
||||
byKind: make(map[nostr.Kind]set.Set[int]),
|
||||
fallback: set.NewSliceSet[int](),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dispatcher) addSubscription(sub subscription) int {
|
||||
d.serial++
|
||||
ssid := d.serial
|
||||
|
||||
d.subscriptions.Store(ssid, sub)
|
||||
|
||||
indexed := false
|
||||
if sub.filter.Authors != nil {
|
||||
indexed = true
|
||||
for _, author := range sub.filter.Authors {
|
||||
s, ok := d.byAuthor[author]
|
||||
if !ok {
|
||||
s = set.NewSliceSet[int]()
|
||||
d.byAuthor[author] = s
|
||||
}
|
||||
s.Add(ssid)
|
||||
}
|
||||
}
|
||||
|
||||
if sub.filter.Kinds != nil {
|
||||
indexed = true
|
||||
for _, kind := range sub.filter.Kinds {
|
||||
s, ok := d.byKind[kind]
|
||||
if !ok {
|
||||
s = set.NewSliceSet[int]()
|
||||
d.byKind[kind] = s
|
||||
}
|
||||
s.Add(ssid)
|
||||
}
|
||||
}
|
||||
|
||||
if !indexed {
|
||||
d.fallback.Add(ssid)
|
||||
}
|
||||
|
||||
return ssid
|
||||
}
|
||||
|
||||
func (d *dispatcher) removeSubscription(ssid int) {
|
||||
sub, ok := d.subscriptions.LoadAndDelete(ssid)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
indexed := false
|
||||
if sub.filter.Authors != nil {
|
||||
indexed = true
|
||||
for _, author := range sub.filter.Authors {
|
||||
s, ok := d.byAuthor[author]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.Remove(ssid)
|
||||
if s.Len() == 0 {
|
||||
delete(d.byAuthor, author)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sub.filter.Kinds != nil {
|
||||
indexed = true
|
||||
for _, kind := range sub.filter.Kinds {
|
||||
s, ok := d.byKind[kind]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.Remove(ssid)
|
||||
if s.Len() == 0 {
|
||||
delete(d.byKind, kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !indexed {
|
||||
d.fallback.Remove(ssid)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dispatcher) candidates(event nostr.Event) iter.Seq[subscription] {
|
||||
return func(yield func(subscription) bool) {
|
||||
authorSubs, hasAuthorSubs := d.byAuthor[event.PubKey]
|
||||
kindSubs, hasKindSubs := d.byKind[event.Kind]
|
||||
|
||||
if hasAuthorSubs && hasKindSubs {
|
||||
for _, ssid := range authorSubs.Slice() {
|
||||
sub, _ := d.subscriptions.Load(ssid)
|
||||
|
||||
if kindSubs.Has(ssid) {
|
||||
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||
if !yield(sub) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// matched author but not tags, so this event doesn't qualify for any filter
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else if hasAuthorSubs {
|
||||
for _, ssid := range authorSubs.Slice() {
|
||||
sub, _ := d.subscriptions.Load(ssid)
|
||||
|
||||
if sub.filter.Kinds != nil {
|
||||
// if there are any kinds in the filter we already know this doesn't qualify
|
||||
continue
|
||||
}
|
||||
|
||||
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||
if !yield(sub) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if hasKindSubs {
|
||||
for _, ssid := range kindSubs.Slice() {
|
||||
sub, _ := d.subscriptions.Load(ssid)
|
||||
|
||||
if sub.filter.Authors != nil {
|
||||
// if there are any authors in the filter we already know this doesn't qualify
|
||||
continue
|
||||
}
|
||||
|
||||
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||
if !yield(sub) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ssid := range d.fallback.Slice() {
|
||||
sub, _ := d.subscriptions.Load(ssid)
|
||||
|
||||
if filterMatchesTimestampConstraintsAndTags(sub.filter, event) {
|
||||
if !yield(sub) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func filterMatchesTimestampConstraintsAndTags(filter nostr.Filter, event nostr.Event) bool {
|
||||
if filter.Since != 0 && event.CreatedAt < filter.Since {
|
||||
return false
|
||||
}
|
||||
|
||||
if filter.Until != 0 && event.CreatedAt > filter.Until {
|
||||
return false
|
||||
}
|
||||
|
||||
for f, v := range filter.Tags {
|
||||
if !event.Tags.ContainsAny(f, v) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func tagKeyValueKey(tagKey, tagValue string) string {
|
||||
return tagKey + "\x00" + tagValue
|
||||
}
|
||||
|
||||
func (rl *Relay) GetListeningFilters() []nostr.Filter {
|
||||
respfilters := make([]nostr.Filter, len(rl.listeners))
|
||||
for i, l := range rl.listeners {
|
||||
respfilters[i] = l.filter
|
||||
respfilters := make([]nostr.Filter, 0, rl.dispatcher.subscriptions.Size())
|
||||
for _, sub := range rl.dispatcher.subscriptions.Range {
|
||||
respfilters = append(respfilters, sub.filter)
|
||||
}
|
||||
return respfilters
|
||||
}
|
||||
@@ -36,26 +226,27 @@ func (rl *Relay) GetListeningFilters() []nostr.Filter {
|
||||
func (rl *Relay) addListener(
|
||||
ws *WebSocket,
|
||||
id string,
|
||||
subrelay *Relay,
|
||||
filter nostr.Filter,
|
||||
cancel context.CancelCauseFunc,
|
||||
) {
|
||||
rl.clientsMutex.Lock()
|
||||
defer rl.clientsMutex.Unlock()
|
||||
select {
|
||||
case <-rl.clientsMutex.C():
|
||||
defer rl.clientsMutex.Unlock()
|
||||
case <-ws.Context.Done():
|
||||
return
|
||||
}
|
||||
|
||||
if specs, ok := rl.clients[ws]; ok /* this will always be true unless client has disconnected very rapidly */ {
|
||||
idx := len(subrelay.listeners)
|
||||
rl.clients[ws] = append(specs, listenerSpec{
|
||||
id: id,
|
||||
cancel: cancel,
|
||||
subrelay: subrelay,
|
||||
index: idx,
|
||||
})
|
||||
subrelay.listeners = append(subrelay.listeners, listener{
|
||||
ssid := rl.dispatcher.addSubscription(subscription{
|
||||
ws: ws,
|
||||
id: id,
|
||||
filter: filter,
|
||||
})
|
||||
rl.clients[ws] = append(specs, listenerSpec{
|
||||
ssid: ssid,
|
||||
cancel: cancel,
|
||||
sid: id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,35 +257,16 @@ func (rl *Relay) removeListenerId(ws *WebSocket, id string) {
|
||||
defer rl.clientsMutex.Unlock()
|
||||
|
||||
if specs, ok := rl.clients[ws]; ok {
|
||||
// swap delete specs that match this id
|
||||
for s := len(specs) - 1; s >= 0; s-- {
|
||||
spec := specs[s]
|
||||
if spec.id == id {
|
||||
kept := specs[:0]
|
||||
for _, spec := range specs {
|
||||
if spec.sid == id {
|
||||
spec.cancel(ErrSubscriptionClosedByClient)
|
||||
specs[s] = specs[len(specs)-1]
|
||||
specs = specs[0 : len(specs)-1]
|
||||
rl.clients[ws] = specs
|
||||
|
||||
// swap delete listeners one at a time, as they may be each in a different subrelay
|
||||
srl := spec.subrelay // == rl in normal cases, but different when this came from a route
|
||||
|
||||
if spec.index != len(srl.listeners)-1 {
|
||||
movedFromIndex := len(srl.listeners) - 1
|
||||
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
|
||||
srl.listeners[spec.index] = moved
|
||||
|
||||
// now we must update the the listener we just moved
|
||||
// so its .index reflects its new position on srl.listeners
|
||||
movedSpecs := rl.clients[moved.ws]
|
||||
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
|
||||
return ls.index == movedFromIndex && ls.subrelay == srl
|
||||
})
|
||||
movedSpecs[idx].index = spec.index
|
||||
rl.clients[moved.ws] = movedSpecs
|
||||
}
|
||||
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
|
||||
rl.dispatcher.removeSubscription(spec.ssid)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, spec)
|
||||
}
|
||||
rl.clients[ws] = kept
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,31 +274,9 @@ func (rl *Relay) removeClientAndListeners(ws *WebSocket) {
|
||||
rl.clientsMutex.Lock()
|
||||
defer rl.clientsMutex.Unlock()
|
||||
if specs, ok := rl.clients[ws]; ok {
|
||||
// swap delete listeners and delete client (all specs will be deleted)
|
||||
for s, spec := range specs {
|
||||
for _, spec := range specs {
|
||||
// no need to cancel contexts since they inherit from the main connection context
|
||||
// just delete the listeners (swap-delete)
|
||||
srl := spec.subrelay
|
||||
|
||||
if spec.index != len(srl.listeners)-1 {
|
||||
movedFromIndex := len(srl.listeners) - 1
|
||||
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
|
||||
srl.listeners[spec.index] = moved
|
||||
|
||||
// temporarily update the spec of the listener being removed to have index == -1
|
||||
// (since it was removed) so it doesn't match in the search below
|
||||
rl.clients[ws][s].index = -1
|
||||
|
||||
// now we must update the the listener we just moved
|
||||
// so its .index reflects its new position on srl.listeners
|
||||
movedSpecs := rl.clients[moved.ws]
|
||||
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
|
||||
return ls.index == movedFromIndex && ls.subrelay == srl
|
||||
})
|
||||
movedSpecs[idx].index = spec.index
|
||||
rl.clients[moved.ws] = movedSpecs
|
||||
}
|
||||
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
|
||||
rl.dispatcher.removeSubscription(spec.ssid)
|
||||
}
|
||||
}
|
||||
delete(rl.clients, ws)
|
||||
@@ -136,16 +286,14 @@ func (rl *Relay) removeClientAndListeners(ws *WebSocket) {
|
||||
func (rl *Relay) notifyListeners(event nostr.Event, skipPrevent bool) int {
|
||||
count := 0
|
||||
listenersloop:
|
||||
for _, listener := range rl.listeners {
|
||||
if listener.filter.Matches(event) {
|
||||
if !skipPrevent && nil != rl.PreventBroadcast {
|
||||
if rl.PreventBroadcast(listener.ws, listener.filter, event) {
|
||||
continue listenersloop
|
||||
}
|
||||
for sub := range rl.dispatcher.candidates(event) {
|
||||
if !skipPrevent && nil != rl.PreventBroadcast {
|
||||
if rl.PreventBroadcast(sub.ws, sub.filter, event) {
|
||||
continue listenersloop
|
||||
}
|
||||
listener.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &listener.id, Event: event})
|
||||
count++
|
||||
}
|
||||
sub.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &sub.id, Event: event})
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
||||
l := 0
|
||||
|
||||
for i := 0; i < totalWebsockets; i++ {
|
||||
ws := &WebSocket{}
|
||||
ws := &WebSocket{Context: rl.ctx}
|
||||
websockets = append(websockets, ws)
|
||||
rl.clients[ws] = nil
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
||||
|
||||
if s%addListenerFreq == 0 {
|
||||
l++
|
||||
rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel)
|
||||
rl.addListener(ws, w+":"+idFromSeqLower(j), f, cancel)
|
||||
}
|
||||
|
||||
s++
|
||||
@@ -46,14 +46,22 @@ func FuzzRandomListenerClientRemoving(f *testing.F) {
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, totalWebsockets)
|
||||
require.Len(t, rl.listeners, l)
|
||||
ssidCount := 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, l, ssidCount)
|
||||
|
||||
for ws := range rl.clients {
|
||||
rl.removeClientAndListeners(ws)
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, 0)
|
||||
require.Len(t, rl.listeners, 0)
|
||||
ssidCount = 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, 0, ssidCount)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -84,7 +92,7 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
||||
extra := 0
|
||||
|
||||
for i := 0; i < totalWebsockets; i++ {
|
||||
ws := &WebSocket{}
|
||||
ws := &WebSocket{Context: rl.ctx}
|
||||
websockets = append(websockets, ws)
|
||||
rl.clients[ws] = nil
|
||||
}
|
||||
@@ -97,11 +105,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
||||
|
||||
if s%addListenerFreq == 0 {
|
||||
id := w + ":" + idFromSeqLower(j)
|
||||
rl.addListener(ws, id, rl, f, cancel)
|
||||
rl.addListener(ws, id, f, cancel)
|
||||
subs = append(subs, wsid{ws, id})
|
||||
|
||||
if s%addExtraListenerFreq == 0 {
|
||||
rl.addListener(ws, id, rl, f, cancel)
|
||||
rl.addListener(ws, id, f, cancel)
|
||||
extra++
|
||||
}
|
||||
}
|
||||
@@ -111,7 +119,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, totalWebsockets)
|
||||
require.Len(t, rl.listeners, len(subs)+extra)
|
||||
ssidCount := 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, len(subs)+extra, ssidCount)
|
||||
|
||||
rand.Shuffle(len(subs), func(i, j int) {
|
||||
subs[i], subs[j] = subs[j], subs[i]
|
||||
@@ -120,7 +132,11 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
||||
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
||||
}
|
||||
|
||||
require.Len(t, rl.listeners, 0)
|
||||
ssidCount = 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, 0, ssidCount)
|
||||
require.Len(t, rl.clients, totalWebsockets)
|
||||
for _, specs := range rl.clients {
|
||||
require.Len(t, specs, 0)
|
||||
@@ -129,23 +145,17 @@ func FuzzRandomListenerIdRemoving(f *testing.F) {
|
||||
}
|
||||
|
||||
func FuzzRouterListenersPabloCrash(f *testing.F) {
|
||||
f.Add(uint(3), uint(6), uint(2), uint(20))
|
||||
f.Fuzz(func(t *testing.T, totalRelays uint, totalConns uint, subFreq uint, subIterations uint) {
|
||||
totalRelays++
|
||||
f.Add(uint(6), uint(2), uint(20))
|
||||
f.Fuzz(func(t *testing.T, totalConns uint, subFreq uint, subIterations uint) {
|
||||
totalConns++
|
||||
subFreq++
|
||||
subIterations++
|
||||
|
||||
rl := NewRelay()
|
||||
|
||||
relays := make([]*Relay, int(totalRelays))
|
||||
for i := 0; i < int(totalRelays); i++ {
|
||||
relays[i] = NewRelay()
|
||||
}
|
||||
|
||||
conns := make([]*WebSocket, int(totalConns))
|
||||
for i := 0; i < int(totalConns); i++ {
|
||||
ws := &WebSocket{}
|
||||
ws := &WebSocket{Context: rl.ctx}
|
||||
conns[i] = ws
|
||||
rl.clients[ws] = make([]listenerSpec, 0, subIterations)
|
||||
}
|
||||
@@ -159,18 +169,16 @@ func FuzzRouterListenersPabloCrash(f *testing.F) {
|
||||
}
|
||||
|
||||
s := 0
|
||||
subs := make([]wsid, 0, subIterations*totalConns*totalRelays)
|
||||
subs := make([]wsid, 0, subIterations*totalConns)
|
||||
for i, conn := range conns {
|
||||
w := idFromSeqUpper(i)
|
||||
for j := 0; j < int(subIterations); j++ {
|
||||
id := w + ":" + idFromSeqLower(j)
|
||||
for _, rlt := range relays {
|
||||
if s%int(subFreq) == 0 {
|
||||
rl.addListener(conn, id, rlt, f, cancel)
|
||||
subs = append(subs, wsid{conn, id})
|
||||
}
|
||||
s++
|
||||
if s%int(subFreq) == 0 {
|
||||
rl.addListener(conn, id, f, cancel)
|
||||
subs = append(subs, wsid{conn, id})
|
||||
}
|
||||
s++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,8 +189,5 @@ func FuzzRouterListenersPabloCrash(f *testing.F) {
|
||||
for _, wsid := range subs {
|
||||
require.Len(t, rl.clients[wsid.ws], 0)
|
||||
}
|
||||
for _, rlt := range relays {
|
||||
require.Len(t, rlt.listeners, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+117
-225
@@ -26,8 +26,8 @@ func idFromSeq(seq int, min, max int) string {
|
||||
func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
||||
rl := NewRelay()
|
||||
|
||||
ws1 := &WebSocket{}
|
||||
ws2 := &WebSocket{}
|
||||
ws1 := &WebSocket{Context: rl.ctx}
|
||||
ws2 := &WebSocket{Context: rl.ctx}
|
||||
|
||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||
@@ -39,28 +39,21 @@ func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
||||
var cancel func(cause error) = nil
|
||||
|
||||
t.Run("adding listeners", func(t *testing.T) {
|
||||
rl.addListener(ws1, "1a", rl, f1, cancel)
|
||||
rl.addListener(ws1, "1b", rl, f2, cancel)
|
||||
rl.addListener(ws2, "2a", rl, f3, cancel)
|
||||
rl.addListener(ws1, "1c", rl, f3, cancel)
|
||||
rl.addListener(ws1, "1a", f1, cancel)
|
||||
rl.addListener(ws1, "1b", f2, cancel)
|
||||
rl.addListener(ws2, "2a", f3, cancel)
|
||||
rl.addListener(ws1, "1c", f3, cancel)
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"1a", cancel, 0, rl},
|
||||
{"1b", cancel, 1, rl},
|
||||
{"1c", cancel, 3, rl},
|
||||
{1, "1a", cancel},
|
||||
{2, "1b", cancel},
|
||||
{4, "1c", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"2a", cancel, 2, rl},
|
||||
{3, "2a", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"1a", f1, ws1},
|
||||
{"1b", f2, ws1},
|
||||
{"2a", f3, ws2},
|
||||
{"1c", f3, ws1},
|
||||
}, rl.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing a client", func(t *testing.T) {
|
||||
@@ -68,23 +61,19 @@ func TestListenerSetupAndRemoveOnce(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws2: {
|
||||
{"2a", cancel, 0, rl},
|
||||
{3, "2a", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"2a", f3, ws2},
|
||||
}, rl.listeners)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerMoreConvolutedCase(t *testing.T) {
|
||||
rl := NewRelay()
|
||||
|
||||
ws1 := &WebSocket{}
|
||||
ws2 := &WebSocket{}
|
||||
ws3 := &WebSocket{}
|
||||
ws4 := &WebSocket{}
|
||||
ws1 := &WebSocket{Context: rl.ctx}
|
||||
ws2 := &WebSocket{Context: rl.ctx}
|
||||
ws3 := &WebSocket{Context: rl.ctx}
|
||||
ws4 := &WebSocket{Context: rl.ctx}
|
||||
|
||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||
@@ -98,35 +87,27 @@ func TestListenerMoreConvolutedCase(t *testing.T) {
|
||||
var cancel func(cause error) = nil
|
||||
|
||||
t.Run("adding listeners", func(t *testing.T) {
|
||||
rl.addListener(ws1, "c", rl, f1, cancel)
|
||||
rl.addListener(ws2, "b", rl, f2, cancel)
|
||||
rl.addListener(ws3, "a", rl, f3, cancel)
|
||||
rl.addListener(ws4, "d", rl, f3, cancel)
|
||||
rl.addListener(ws2, "b", rl, f1, cancel)
|
||||
rl.addListener(ws1, "c", f1, cancel)
|
||||
rl.addListener(ws2, "b", f2, cancel)
|
||||
rl.addListener(ws3, "a", f3, cancel)
|
||||
rl.addListener(ws4, "d", f3, cancel)
|
||||
rl.addListener(ws2, "b", f1, cancel)
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rl},
|
||||
{1, "c", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"b", cancel, 1, rl},
|
||||
{"b", cancel, 4, rl},
|
||||
{2, "b", cancel},
|
||||
{5, "b", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 2, rl},
|
||||
{3, "a", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"d", cancel, 3, rl},
|
||||
{4, "d", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"b", f2, ws2},
|
||||
{"a", f3, ws3},
|
||||
{"d", f3, ws4},
|
||||
{"b", f1, ws2},
|
||||
}, rl.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing a client", func(t *testing.T) {
|
||||
@@ -134,85 +115,62 @@ func TestListenerMoreConvolutedCase(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rl},
|
||||
{1, "c", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 2, rl},
|
||||
{3, "a", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"d", cancel, 1, rl},
|
||||
{4, "d", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"d", f3, ws4},
|
||||
{"a", f3, ws3},
|
||||
}, rl.listeners)
|
||||
})
|
||||
|
||||
t.Run("reorganize the first case differently and then remove again", func(t *testing.T) {
|
||||
rl.clients = map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 1, rl},
|
||||
{2, "c", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"b", cancel, 2, rl},
|
||||
{"b", cancel, 4, rl},
|
||||
{3, "b", cancel},
|
||||
{5, "b", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 0, rl},
|
||||
{1, "a", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"d", cancel, 3, rl},
|
||||
{4, "d", cancel},
|
||||
},
|
||||
}
|
||||
rl.listeners = []listener{
|
||||
{"a", f3, ws3},
|
||||
{"c", f1, ws1},
|
||||
{"b", f2, ws2},
|
||||
{"d", f3, ws4},
|
||||
{"b", f1, ws2},
|
||||
}
|
||||
|
||||
rl.removeClientAndListeners(ws2)
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 1, rl},
|
||||
{2, "c", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 0, rl},
|
||||
{1, "a", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"d", cancel, 2, rl},
|
||||
{4, "d", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"a", f3, ws3},
|
||||
{"c", f1, ws1},
|
||||
{"d", f3, ws4},
|
||||
}, rl.listeners)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
rl := NewRelay()
|
||||
|
||||
ws1 := &WebSocket{}
|
||||
ws2 := &WebSocket{}
|
||||
ws3 := &WebSocket{}
|
||||
ws4 := &WebSocket{}
|
||||
ws1 := &WebSocket{Context: rl.ctx}
|
||||
ws2 := &WebSocket{Context: rl.ctx}
|
||||
ws3 := &WebSocket{Context: rl.ctx}
|
||||
ws4 := &WebSocket{Context: rl.ctx}
|
||||
|
||||
f1 := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||
f2 := nostr.Filter{Kinds: []nostr.Kind{2}}
|
||||
f3 := nostr.Filter{Kinds: []nostr.Kind{3}}
|
||||
|
||||
rlx := NewRelay()
|
||||
rly := NewRelay()
|
||||
rlz := NewRelay()
|
||||
|
||||
rl.clients[ws1] = nil
|
||||
rl.clients[ws2] = nil
|
||||
rl.clients[ws3] = nil
|
||||
@@ -221,56 +179,37 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
var cancel func(cause error) = nil
|
||||
|
||||
t.Run("adding listeners", func(t *testing.T) {
|
||||
rl.addListener(ws1, "c", rlx, f1, cancel)
|
||||
rl.addListener(ws2, "b", rly, f2, cancel)
|
||||
rl.addListener(ws3, "a", rlz, f3, cancel)
|
||||
rl.addListener(ws4, "d", rlx, f3, cancel)
|
||||
rl.addListener(ws4, "e", rlx, f3, cancel)
|
||||
rl.addListener(ws3, "a", rlx, f3, cancel)
|
||||
rl.addListener(ws4, "e", rly, f3, cancel)
|
||||
rl.addListener(ws3, "f", rly, f3, cancel)
|
||||
rl.addListener(ws1, "g", rlz, f1, cancel)
|
||||
rl.addListener(ws2, "g", rlz, f2, cancel)
|
||||
rl.addListener(ws1, "c", f1, cancel)
|
||||
rl.addListener(ws2, "b", f2, cancel)
|
||||
rl.addListener(ws3, "a", f3, cancel)
|
||||
rl.addListener(ws4, "d", f3, cancel)
|
||||
rl.addListener(ws4, "e", f3, cancel)
|
||||
rl.addListener(ws3, "a", f3, cancel)
|
||||
rl.addListener(ws4, "e", f3, cancel)
|
||||
rl.addListener(ws3, "f", f3, cancel)
|
||||
rl.addListener(ws1, "g", f1, cancel)
|
||||
rl.addListener(ws2, "g", f2, cancel)
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rlx},
|
||||
{"g", cancel, 1, rlz},
|
||||
{1, "c", cancel},
|
||||
{9, "g", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"b", cancel, 0, rly},
|
||||
{"g", cancel, 2, rlz},
|
||||
{2, "b", cancel},
|
||||
{10, "g", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 0, rlz},
|
||||
{"a", cancel, 3, rlx},
|
||||
{"f", cancel, 2, rly},
|
||||
{3, "a", cancel},
|
||||
{6, "a", cancel},
|
||||
{8, "f", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"d", cancel, 1, rlx},
|
||||
{"e", cancel, 2, rlx},
|
||||
{"e", cancel, 1, rly},
|
||||
{4, "d", cancel},
|
||||
{5, "e", cancel},
|
||||
{7, "e", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"d", f3, ws4},
|
||||
{"e", f3, ws4},
|
||||
{"a", f3, ws3},
|
||||
}, rlx.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"b", f2, ws2},
|
||||
{"e", f3, ws4},
|
||||
{"f", f3, ws3},
|
||||
}, rly.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"a", f3, ws3},
|
||||
{"g", f1, ws1},
|
||||
{"g", f2, ws2},
|
||||
}, rlz.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing a subscription id", func(t *testing.T) {
|
||||
@@ -280,41 +219,23 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rlx},
|
||||
{"g", cancel, 1, rlz},
|
||||
{1, "c", cancel},
|
||||
{9, "g", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"b", cancel, 0, rly},
|
||||
{"g", cancel, 2, rlz},
|
||||
{2, "b", cancel},
|
||||
{10, "g", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"a", cancel, 0, rlz},
|
||||
{"a", cancel, 1, rlx},
|
||||
{"f", cancel, 2, rly},
|
||||
{3, "a", cancel},
|
||||
{6, "a", cancel},
|
||||
{8, "f", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"e", cancel, 1, rly},
|
||||
{"e", cancel, 2, rlx},
|
||||
{5, "e", cancel},
|
||||
{7, "e", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"a", f3, ws3},
|
||||
{"e", f3, ws4},
|
||||
}, rlx.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"b", f2, ws2},
|
||||
{"e", f3, ws4},
|
||||
{"f", f3, ws3},
|
||||
}, rly.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"a", f3, ws3},
|
||||
{"g", f1, ws1},
|
||||
{"g", f2, ws2},
|
||||
}, rlz.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing another subscription id", func(t *testing.T) {
|
||||
@@ -325,37 +246,21 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rlx},
|
||||
{"g", cancel, 1, rlz},
|
||||
{1, "c", cancel},
|
||||
{9, "g", cancel},
|
||||
},
|
||||
ws2: {
|
||||
{"b", cancel, 0, rly},
|
||||
{"g", cancel, 0, rlz},
|
||||
{2, "b", cancel},
|
||||
{10, "g", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"f", cancel, 2, rly},
|
||||
{8, "f", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"e", cancel, 1, rly},
|
||||
{"e", cancel, 1, rlx},
|
||||
{5, "e", cancel},
|
||||
{7, "e", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"e", f3, ws4},
|
||||
}, rlx.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"b", f2, ws2},
|
||||
{"e", f3, ws4},
|
||||
{"f", f3, ws3},
|
||||
}, rly.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"g", f2, ws2},
|
||||
{"g", f1, ws1},
|
||||
}, rlz.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing a connection", func(t *testing.T) {
|
||||
@@ -363,31 +268,17 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rlx},
|
||||
{"g", cancel, 0, rlz},
|
||||
{1, "c", cancel},
|
||||
{9, "g", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"f", cancel, 0, rly},
|
||||
{8, "f", cancel},
|
||||
},
|
||||
ws4: {
|
||||
{"e", cancel, 1, rly},
|
||||
{"e", cancel, 1, rlx},
|
||||
{5, "e", cancel},
|
||||
{7, "e", cancel},
|
||||
},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
{"e", f3, ws4},
|
||||
}, rlx.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"f", f3, ws3},
|
||||
{"e", f3, ws4},
|
||||
}, rly.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"g", f1, ws1},
|
||||
}, rlz.listeners)
|
||||
})
|
||||
|
||||
t.Run("removing another subscription id", func(t *testing.T) {
|
||||
@@ -398,26 +289,14 @@ func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
|
||||
|
||||
require.Equal(t, map[*WebSocket][]listenerSpec{
|
||||
ws1: {
|
||||
{"c", cancel, 0, rlx},
|
||||
{"g", cancel, 0, rlz},
|
||||
{1, "c", cancel},
|
||||
{9, "g", cancel},
|
||||
},
|
||||
ws3: {
|
||||
{"f", cancel, 0, rly},
|
||||
{8, "f", cancel},
|
||||
},
|
||||
ws4: {},
|
||||
}, rl.clients)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"c", f1, ws1},
|
||||
}, rlx.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"f", f3, ws3},
|
||||
}, rly.listeners)
|
||||
|
||||
require.Equal(t, []listener{
|
||||
{"g", f1, ws1},
|
||||
}, rlz.listeners)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -432,7 +311,7 @@ func TestRandomListenerClientRemoving(t *testing.T) {
|
||||
l := 0
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
ws := &WebSocket{}
|
||||
ws := &WebSocket{Context: rl.ctx}
|
||||
websockets = append(websockets, ws)
|
||||
rl.clients[ws] = nil
|
||||
}
|
||||
@@ -444,20 +323,28 @@ func TestRandomListenerClientRemoving(t *testing.T) {
|
||||
|
||||
if rand.Intn(2) < 1 {
|
||||
l++
|
||||
rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel)
|
||||
rl.addListener(ws, w+":"+idFromSeqLower(j), f, cancel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, 20)
|
||||
require.Len(t, rl.listeners, l)
|
||||
ssidCount := 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, l, ssidCount)
|
||||
|
||||
for ws := range rl.clients {
|
||||
rl.removeClientAndListeners(ws)
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, 0)
|
||||
require.Len(t, rl.listeners, 0)
|
||||
ssidCount = 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, 0, ssidCount)
|
||||
}
|
||||
|
||||
func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
@@ -477,7 +364,7 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
extra := 0
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
ws := &WebSocket{}
|
||||
ws := &WebSocket{Context: rl.ctx}
|
||||
websockets = append(websockets, ws)
|
||||
rl.clients[ws] = nil
|
||||
}
|
||||
@@ -489,11 +376,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
|
||||
if rand.Intn(2) < 1 {
|
||||
id := w + ":" + idFromSeqLower(j)
|
||||
rl.addListener(ws, id, rl, f, cancel)
|
||||
rl.addListener(ws, id, f, cancel)
|
||||
subs = append(subs, wsid{ws, id})
|
||||
|
||||
if rand.Intn(5) < 1 {
|
||||
rl.addListener(ws, id, rl, f, cancel)
|
||||
rl.addListener(ws, id, f, cancel)
|
||||
extra++
|
||||
}
|
||||
}
|
||||
@@ -501,7 +388,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Len(t, rl.clients, 20)
|
||||
require.Len(t, rl.listeners, len(subs)+extra)
|
||||
ssidCount := 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, len(subs)+extra, ssidCount)
|
||||
|
||||
rand.Shuffle(len(subs), func(i, j int) {
|
||||
subs[i], subs[j] = subs[j], subs[i]
|
||||
@@ -510,7 +401,11 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id)
|
||||
}
|
||||
|
||||
require.Len(t, rl.listeners, 0)
|
||||
ssidCount = 0
|
||||
for _, specs := range rl.clients {
|
||||
ssidCount += len(specs)
|
||||
}
|
||||
require.Equal(t, 0, ssidCount)
|
||||
require.Len(t, rl.clients, 20)
|
||||
for _, specs := range rl.clients {
|
||||
require.Len(t, specs, 0)
|
||||
@@ -520,12 +415,9 @@ func TestRandomListenerIdRemoving(t *testing.T) {
|
||||
func TestRouterListenersPabloCrash(t *testing.T) {
|
||||
rl := NewRelay()
|
||||
|
||||
rla := NewRelay()
|
||||
rlb := NewRelay()
|
||||
|
||||
ws1 := &WebSocket{}
|
||||
ws2 := &WebSocket{}
|
||||
ws3 := &WebSocket{}
|
||||
ws1 := &WebSocket{Context: rl.ctx}
|
||||
ws2 := &WebSocket{Context: rl.ctx}
|
||||
ws3 := &WebSocket{Context: rl.ctx}
|
||||
|
||||
rl.clients[ws1] = nil
|
||||
rl.clients[ws2] = nil
|
||||
@@ -534,11 +426,11 @@ func TestRouterListenersPabloCrash(t *testing.T) {
|
||||
f := nostr.Filter{Kinds: []nostr.Kind{1}}
|
||||
cancel := func(cause error) {}
|
||||
|
||||
rl.addListener(ws1, ":1", rla, f, cancel)
|
||||
rl.addListener(ws2, ":1", rlb, f, cancel)
|
||||
rl.addListener(ws3, "a", rlb, f, cancel)
|
||||
rl.addListener(ws3, "b", rla, f, cancel)
|
||||
rl.addListener(ws3, "c", rlb, f, cancel)
|
||||
rl.addListener(ws1, ":1", f, cancel)
|
||||
rl.addListener(ws2, ":1", f, cancel)
|
||||
rl.addListener(ws3, "a", f, cancel)
|
||||
rl.addListener(ws3, "b", f, cancel)
|
||||
rl.addListener(ws3, "c", f, cancel)
|
||||
|
||||
rl.removeClientAndListeners(ws1)
|
||||
rl.removeClientAndListeners(ws3)
|
||||
|
||||
@@ -21,8 +21,10 @@ type RelayManagementAPI struct {
|
||||
|
||||
BanPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||
ListBannedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
||||
UnbanPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||
AllowPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||
ListAllowedPubKeys func(ctx context.Context) ([]nip86.PubKeyReason, error)
|
||||
UnallowPubKey func(ctx context.Context, pubkey nostr.PubKey, reason string) error
|
||||
ListEventsNeedingModeration func(ctx context.Context) ([]nip86.IDReason, error)
|
||||
AllowEvent func(ctx context.Context, id nostr.ID, reason string) error
|
||||
BanEvent func(ctx context.Context, id nostr.ID, reason string) error
|
||||
@@ -168,6 +170,14 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
resp.Result = result
|
||||
}
|
||||
case nip86.UnbanPubKey:
|
||||
if rl.ManagementAPI.UnbanPubKey == nil {
|
||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||
} else if err := rl.ManagementAPI.UnbanPubKey(ctx, thing.PubKey, thing.Reason); err != nil {
|
||||
resp.Error = err.Error()
|
||||
} else {
|
||||
resp.Result = true
|
||||
}
|
||||
case nip86.AllowPubKey:
|
||||
if rl.ManagementAPI.AllowPubKey == nil {
|
||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||
@@ -184,6 +194,14 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
resp.Result = result
|
||||
}
|
||||
case nip86.UnallowPubKey:
|
||||
if rl.ManagementAPI.UnallowPubKey == nil {
|
||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||
} else if err := rl.ManagementAPI.UnallowPubKey(ctx, thing.PubKey, thing.Reason); err != nil {
|
||||
resp.Error = err.Error()
|
||||
} else {
|
||||
resp.Result = true
|
||||
}
|
||||
case nip86.BanEvent:
|
||||
if rl.ManagementAPI.BanEvent == nil {
|
||||
resp.Error = fmt.Sprintf("method %s not supported", thing.MethodName())
|
||||
|
||||
@@ -3,6 +3,7 @@ package policies
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -110,6 +111,9 @@ func RejectEventsWithBase64Media(ctx context.Context, evt nostr.Event) (bool, st
|
||||
}
|
||||
|
||||
func OnlyAllowNIP70ProtectedEvents(ctx context.Context, event nostr.Event) (reject bool, msg string) {
|
||||
if event.Kind == 5 {
|
||||
return false, ""
|
||||
}
|
||||
if nip70.IsProtected(event) {
|
||||
return false, ""
|
||||
}
|
||||
@@ -120,6 +124,9 @@ var nostrReferencesPrefix = regexp.MustCompile(`\b(nevent1|npub1|nprofile1|note1
|
||||
|
||||
func RejectUnprefixedNostrReferences(ctx context.Context, event nostr.Event) (bool, string) {
|
||||
content := sdk.GetMainContent(event)
|
||||
if content == "" {
|
||||
content = event.Content
|
||||
}
|
||||
|
||||
// only do it for stuff that wasn't parsed as blocks already
|
||||
// (since those are already good references or URLs)
|
||||
@@ -144,3 +151,55 @@ func RejectUnprefixedNostrReferences(ctx context.Context, event nostr.Event) (bo
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// PreventNormalDuplicates prevents normal events that refer to the same thing from being saved.
|
||||
// For kinds 6, 7, 16, 1018 it checks "e" tags.
|
||||
// For kind 1163 it checks "p" tags.
|
||||
// For kinds 1163, 6, 16, 7516, 7517 it checks "a" tags.
|
||||
func PreventNormalDuplicates(query func(nostr.Filter, int) iter.Seq[nostr.Event]) func(ctx context.Context, event nostr.Event) (bool, string) {
|
||||
exists := func(event nostr.Event, tagName string) bool {
|
||||
hasAll := true
|
||||
for t := range event.Tags.FindAll(tagName) {
|
||||
hasThis := false
|
||||
for range query(nostr.Filter{
|
||||
Authors: []nostr.PubKey{event.PubKey},
|
||||
Kinds: []nostr.Kind{event.Kind},
|
||||
Tags: nostr.TagMap{tagName: []string{t[1]}},
|
||||
}, 1) {
|
||||
hasThis = true
|
||||
}
|
||||
if !hasThis {
|
||||
hasAll = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return hasAll
|
||||
}
|
||||
|
||||
return func(ctx context.Context, event nostr.Event) (bool, string) {
|
||||
reject := false
|
||||
|
||||
switch event.Kind {
|
||||
case 6:
|
||||
reject = exists(event, "e") && exists(event, "a")
|
||||
case 7:
|
||||
reject = exists(event, "e") && exists(event, "a")
|
||||
case 16:
|
||||
reject = exists(event, "e") && exists(event, "a")
|
||||
case 1018:
|
||||
reject = exists(event, "e")
|
||||
case 1163:
|
||||
reject = exists(event, "p") && exists(event, "a")
|
||||
case 7516:
|
||||
reject = exists(event, "a")
|
||||
case 7517:
|
||||
reject = exists(event, "a")
|
||||
}
|
||||
|
||||
if reject {
|
||||
return true, "an event similar to this already exists"
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
}
|
||||
|
||||
+39
-13
@@ -8,9 +8,9 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fiatjaf.com/lib/channelmutex"
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore"
|
||||
"fiatjaf.com/nostr/nip11"
|
||||
@@ -39,8 +39,10 @@ func NewRelay() *Relay {
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
|
||||
clients: make(map[*WebSocket][]listenerSpec, 100),
|
||||
listeners: make([]listener, 0, 100),
|
||||
clients: make(map[*WebSocket][]listenerSpec, 100),
|
||||
clientsMutex: channelmutex.New(),
|
||||
|
||||
dispatcher: newDispatcher(),
|
||||
|
||||
serveMux: &http.ServeMux{},
|
||||
|
||||
@@ -66,9 +68,10 @@ type Relay struct {
|
||||
// hooks that will be called at various times
|
||||
OnEvent func(ctx context.Context, event nostr.Event) (reject bool, msg string)
|
||||
StoreEvent func(ctx context.Context, event nostr.Event) error
|
||||
ReplaceEvent func(ctx context.Context, event nostr.Event) error
|
||||
ReplaceEvent func(ctx context.Context, event nostr.Event) ([]nostr.Event, error)
|
||||
DeleteEvent func(ctx context.Context, id nostr.ID) error
|
||||
OnEventSaved func(ctx context.Context, event nostr.Event)
|
||||
OnEventDeleted func(ctx context.Context, deleted nostr.Event)
|
||||
OnEphemeralEvent func(ctx context.Context, event nostr.Event)
|
||||
OnRequest func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
||||
OnCount func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
||||
@@ -84,11 +87,6 @@ type Relay struct {
|
||||
// this can be ignored unless you know what you're doing
|
||||
ChallengePrefix string
|
||||
|
||||
// these are used when this relays acts as a router
|
||||
routes []Route
|
||||
getSubRelayFromEvent func(*nostr.Event) *Relay // used for handling EVENTs
|
||||
getSubRelayFromFilter func(nostr.Filter) *Relay // used for handling REQs
|
||||
|
||||
// setting up handlers here will enable these methods
|
||||
ManagementAPI RelayManagementAPI
|
||||
|
||||
@@ -105,8 +103,8 @@ type Relay struct {
|
||||
// keep a connection reference to all connected clients for Server.Shutdown
|
||||
// also used for keeping track of who is listening to what
|
||||
clients map[*WebSocket][]listenerSpec
|
||||
listeners []listener
|
||||
clientsMutex sync.Mutex
|
||||
dispatcher dispatcher
|
||||
clientsMutex *channelmutex.Mutex
|
||||
|
||||
// set this to true to support negentropy
|
||||
Negentropy bool
|
||||
@@ -147,7 +145,7 @@ func (rl *Relay) UseEventstore(store eventstore.Store, maxQueryLimit int) {
|
||||
rl.StoreEvent = func(ctx context.Context, event nostr.Event) error {
|
||||
return store.SaveEvent(event)
|
||||
}
|
||||
rl.ReplaceEvent = func(ctx context.Context, event nostr.Event) error {
|
||||
rl.ReplaceEvent = func(ctx context.Context, event nostr.Event) ([]nostr.Event, error) {
|
||||
return store.ReplaceEvent(event)
|
||||
}
|
||||
rl.DeleteEvent = func(ctx context.Context, id nostr.ID) error {
|
||||
@@ -155,7 +153,15 @@ func (rl *Relay) UseEventstore(store eventstore.Store, maxQueryLimit int) {
|
||||
}
|
||||
|
||||
// only when using the eventstore we automatically set up the expiration manager
|
||||
rl.StartExpirationManager(rl.QueryStored, rl.DeleteEvent)
|
||||
rl.StartExpirationManager(func(ctx context.Context, filter nostr.Filter) iter.Seq[nostr.Event] {
|
||||
return rl.QueryStored(ctx, filter)
|
||||
}, func(ctx context.Context, id nostr.ID) error {
|
||||
return rl.DeleteEvent(ctx, id)
|
||||
}, func(ctx context.Context, evt nostr.Event) {
|
||||
if rl.OnEventDeleted != nil {
|
||||
rl.OnEventDeleted(ctx, evt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (rl *Relay) getBaseURL(r *http.Request) string {
|
||||
@@ -184,3 +190,23 @@ func (rl *Relay) getBaseURL(r *http.Request) string {
|
||||
|
||||
return proto + "://" + host + r.URL.Path
|
||||
}
|
||||
|
||||
// Stats returns the current number of connected clients and open listeners.
|
||||
func (rl *Relay) Stats() (clients, listeners int) {
|
||||
rl.clientsMutex.Lock()
|
||||
defer rl.clientsMutex.Unlock()
|
||||
|
||||
for _, specs := range rl.clients {
|
||||
listeners += len(specs)
|
||||
}
|
||||
|
||||
return len(rl.clients), listeners
|
||||
}
|
||||
|
||||
func (rl *Relay) Router() *http.ServeMux {
|
||||
return rl.serveMux
|
||||
}
|
||||
|
||||
func (rl *Relay) SetRouter(mux *http.ServeMux) {
|
||||
rl.serveMux = mux
|
||||
}
|
||||
|
||||
+35
-77
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore/slicestore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicRelayFunctionality(t *testing.T) {
|
||||
@@ -46,15 +47,11 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
// connect two test clients
|
||||
url := "ws" + server.URL[4:]
|
||||
client1, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect client1: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to connect client1")
|
||||
defer client1.Close()
|
||||
|
||||
client2, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect client2: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to connect client2")
|
||||
defer client2.Close()
|
||||
|
||||
// test 1: store and query events
|
||||
@@ -64,18 +61,14 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
|
||||
evt1 := createEvent(sk1, 1, "hello world", nil)
|
||||
err := client1.Publish(ctx, evt1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish event")
|
||||
|
||||
// Query the event back
|
||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||
Authors: []nostr.PubKey{pk1},
|
||||
Kinds: []nostr.Kind{1},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
defer sub.Unsub()
|
||||
|
||||
// Wait for event
|
||||
@@ -85,7 +78,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
t.Errorf("got wrong event: %v", env.ID)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for event")
|
||||
require.FailNow(t, "timeout waiting for event")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -99,17 +92,13 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
Authors: []nostr.PubKey{pk2},
|
||||
Kinds: []nostr.Kind{1},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
defer sub.Unsub()
|
||||
|
||||
// Publish event from client2
|
||||
evt2 := createEvent(sk2, 1, "testing live events", nil)
|
||||
err = client2.Publish(ctx, evt2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish event")
|
||||
|
||||
// Wait for event on subscription
|
||||
select {
|
||||
@@ -118,7 +107,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
t.Errorf("got wrong event: %v", env.ID)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for live event")
|
||||
require.FailNow(t, "timeout waiting for live event")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -130,24 +119,18 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
// Create an event to be deleted
|
||||
evt3 := createEvent(sk1, 1, "delete me", nil)
|
||||
err = client1.Publish(ctx, evt3)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish event")
|
||||
|
||||
// Create deletion event
|
||||
delEvent := createEvent(sk1, 5, "deleting", nostr.Tags{{"e", evt3.ID.Hex()}})
|
||||
err = client1.Publish(ctx, delEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish deletion event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish deletion event")
|
||||
|
||||
// Try to query the deleted event
|
||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||
IDs: []nostr.ID{evt3.ID},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
defer sub.Unsub()
|
||||
|
||||
// Should get EOSE without receiving the deleted event
|
||||
@@ -162,7 +145,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
}
|
||||
goto checkDeleteStored
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for EOSE")
|
||||
require.FailNow(t, "timeout waiting for EOSE")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,9 +154,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
subDelete, err := client2.Subscribe(ctx, nostr.Filter{
|
||||
IDs: []nostr.ID{delEvent.ID},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe to delete event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe to delete event")
|
||||
defer subDelete.Unsub()
|
||||
|
||||
gotDeleteEvent := false
|
||||
@@ -189,7 +170,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for EOSE on delete event")
|
||||
require.FailNow(t, "timeout waiting for EOSE on delete event")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -204,36 +185,28 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
evt1.CreatedAt = 1000 // Set specific timestamp for testing
|
||||
evt1.Sign(sk1)
|
||||
err = client1.Publish(ctx, evt1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish initial event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish initial event")
|
||||
|
||||
// create newer event that should replace the first
|
||||
evt2 := createEvent(sk1, 0, `{"name":"newer"}`, nil)
|
||||
evt2.CreatedAt = 2004 // Newer timestamp
|
||||
evt2.Sign(sk1)
|
||||
err = client1.Publish(ctx, evt2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish newer event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish newer event")
|
||||
|
||||
// create older event that should not replace the current one
|
||||
evt3 := createEvent(sk1, 0, `{"name":"older"}`, nil)
|
||||
evt3.CreatedAt = 1500 // Older than evt2
|
||||
evt3.Sign(sk1)
|
||||
err = client1.Publish(ctx, evt3)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish older event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish older event")
|
||||
|
||||
// query to verify only the newest event exists
|
||||
sub, err := client2.Subscribe(ctx, nostr.Filter{
|
||||
Authors: []nostr.PubKey{pk1},
|
||||
Kinds: []nostr.Kind{0},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
defer sub.Unsub()
|
||||
|
||||
// should only get one event back (the newest one)
|
||||
@@ -251,7 +224,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for events")
|
||||
require.FailNow(t, "timeout waiting for events")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -281,26 +254,20 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
// connect test client
|
||||
url := "ws" + server.URL[4:]
|
||||
client, err := nostr.RelayConnect(t.Context(), url, nostr.RelayOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect client: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to connect client")
|
||||
defer client.Close()
|
||||
|
||||
// create event that expires in 2 seconds
|
||||
expiration := strconv.FormatInt(int64(nostr.Now()+2), 10)
|
||||
evt := createEvent(sk1, 1, "i will expire soon", nostr.Tags{{"expiration", expiration}})
|
||||
err = client.Publish(ctx, evt)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish event: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to publish event")
|
||||
|
||||
// verify event exists initially
|
||||
sub, err := client.Subscribe(ctx, nostr.Filter{
|
||||
IDs: []nostr.ID{evt.ID},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
|
||||
// should get the event
|
||||
select {
|
||||
@@ -309,7 +276,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
t.Error("got wrong event")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for event")
|
||||
require.FailNow(t, "timeout waiting for event")
|
||||
}
|
||||
sub.Unsub()
|
||||
|
||||
@@ -320,9 +287,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
sub, err = client.Subscribe(ctx, nostr.Filter{
|
||||
IDs: []nostr.ID{evt.ID},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "failed to subscribe")
|
||||
defer sub.Unsub()
|
||||
|
||||
// should get EOSE without receiving the expired event
|
||||
@@ -337,7 +302,7 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for EOSE")
|
||||
require.FailNow(t, "timeout waiting for EOSE")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -350,33 +315,26 @@ func TestBasicRelayFunctionality(t *testing.T) {
|
||||
// create an event from client1
|
||||
evt4 := createEvent(sk1, 1, "try to delete me", nil)
|
||||
err = client1.Publish(ctx, evt4)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to publish event: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to delete it with client2
|
||||
// try to delete it with client2
|
||||
delEvent := createEvent(sk2, 5, "trying to delete", nostr.Tags{{"e", evt4.ID.Hex()}})
|
||||
err = client2.Publish(ctx, delEvent)
|
||||
if err == nil {
|
||||
t.Fatalf("should have failed to publish deletion event: %v", err)
|
||||
}
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify event still exists
|
||||
// verify event still exists
|
||||
sub, err := client1.Subscribe(ctx, nostr.Filter{
|
||||
IDs: []nostr.ID{evt4.ID},
|
||||
}, nostr.SubscriptionOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to subscribe: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer sub.Unsub()
|
||||
|
||||
select {
|
||||
case env := <-sub.Events:
|
||||
if env.ID != evt4.ID {
|
||||
t.Error("got wrong event")
|
||||
}
|
||||
case env, more := <-sub.Events:
|
||||
require.True(t, more, "should get an event, got nothing")
|
||||
require.Equal(t, env.ID, evt4.ID, "got wrong event")
|
||||
case <-ctx.Done():
|
||||
t.Fatal("event should still exist")
|
||||
require.FailNow(t, "event should still exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -30,7 +30,9 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr
|
||||
// run the function to query events
|
||||
if nil != rl.QueryStored {
|
||||
for event := range rl.QueryStored(ctx, filter) {
|
||||
ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &id, Event: event})
|
||||
if nil != ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &id, Event: event}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package khatru
|
||||
|
||||
import (
|
||||
"fiatjaf.com/nostr"
|
||||
)
|
||||
|
||||
type Router struct{ *Relay }
|
||||
|
||||
type Route struct {
|
||||
eventMatcher func(*nostr.Event) bool
|
||||
filterMatcher func(nostr.Filter) bool
|
||||
relay *Relay
|
||||
}
|
||||
|
||||
type routeBuilder struct {
|
||||
router *Router
|
||||
eventMatcher func(*nostr.Event) bool
|
||||
filterMatcher func(nostr.Filter) bool
|
||||
}
|
||||
|
||||
func NewRouter() *Router {
|
||||
rr := &Router{Relay: NewRelay()}
|
||||
rr.routes = make([]Route, 0, 3)
|
||||
rr.getSubRelayFromFilter = func(f nostr.Filter) *Relay {
|
||||
for _, route := range rr.routes {
|
||||
if route.filterMatcher == nil || route.filterMatcher(f) {
|
||||
return route.relay
|
||||
}
|
||||
}
|
||||
return rr.Relay
|
||||
}
|
||||
rr.getSubRelayFromEvent = func(e *nostr.Event) *Relay {
|
||||
for _, route := range rr.routes {
|
||||
if route.eventMatcher == nil || route.eventMatcher(e) {
|
||||
return route.relay
|
||||
}
|
||||
}
|
||||
return rr.Relay
|
||||
}
|
||||
return rr
|
||||
}
|
||||
|
||||
func (rr *Router) Route() routeBuilder {
|
||||
return routeBuilder{
|
||||
router: rr,
|
||||
filterMatcher: func(f nostr.Filter) bool { return false },
|
||||
eventMatcher: func(e *nostr.Event) bool { return false },
|
||||
}
|
||||
}
|
||||
|
||||
func (rb routeBuilder) Req(fn func(nostr.Filter) bool) routeBuilder {
|
||||
rb.filterMatcher = fn
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb routeBuilder) AnyReq() routeBuilder {
|
||||
rb.filterMatcher = nil
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb routeBuilder) Event(fn func(*nostr.Event) bool) routeBuilder {
|
||||
rb.eventMatcher = fn
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb routeBuilder) AnyEvent() routeBuilder {
|
||||
rb.eventMatcher = nil
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb routeBuilder) Relay(relay *Relay) {
|
||||
rb.router.routes = append(rb.router.routes, Route{
|
||||
filterMatcher: rb.filterMatcher,
|
||||
eventMatcher: rb.eventMatcher,
|
||||
relay: relay,
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package khatru
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
@@ -31,6 +32,9 @@ type WebSocket struct {
|
||||
}
|
||||
|
||||
func (ws *WebSocket) WriteJSON(any any) error {
|
||||
if ws == nil {
|
||||
return fmt.Errorf("connection doesn't exist")
|
||||
}
|
||||
ws.mutex.Lock()
|
||||
err := ws.conn.WriteJSON(any)
|
||||
ws.mutex.Unlock()
|
||||
|
||||
@@ -246,6 +246,8 @@ func (kind Kind) Name() string {
|
||||
return "SimpleGroupMembers"
|
||||
case KindSimpleGroupRoles:
|
||||
return "SimpleGroupRoles"
|
||||
case KindSimpleGroupLiveKitParticipants:
|
||||
return "SimpleGroupLiveKitParticipants"
|
||||
case KindWikiArticle:
|
||||
return "WikiArticle"
|
||||
case KindRedirects:
|
||||
@@ -277,138 +279,139 @@ func (kind Kind) Name() string {
|
||||
}
|
||||
|
||||
const (
|
||||
KindProfileMetadata Kind = 0
|
||||
KindTextNote Kind = 1
|
||||
KindRecommendServer Kind = 2
|
||||
KindFollowList Kind = 3
|
||||
KindEncryptedDirectMessage Kind = 4
|
||||
KindDeletion Kind = 5
|
||||
KindRepost Kind = 6
|
||||
KindReaction Kind = 7
|
||||
KindBadgeAward Kind = 8
|
||||
KindSimpleGroupChatMessage Kind = 9
|
||||
KindSimpleGroupThreadedReply Kind = 10
|
||||
KindSimpleGroupThread Kind = 11
|
||||
KindSimpleGroupReply Kind = 12
|
||||
KindSeal Kind = 13
|
||||
KindDirectMessage Kind = 14
|
||||
KindGenericRepost Kind = 16
|
||||
KindReactionToWebsite Kind = 17
|
||||
KindChannelCreation Kind = 40
|
||||
KindChannelMetadata Kind = 41
|
||||
KindChannelMessage Kind = 42
|
||||
KindChannelHideMessage Kind = 43
|
||||
KindChannelMuteUser Kind = 44
|
||||
KindChess Kind = 64
|
||||
KindMergeRequests Kind = 818
|
||||
KindComment Kind = 1111
|
||||
KindBid Kind = 1021
|
||||
KindBidConfirmation Kind = 1022
|
||||
KindOpenTimestamps Kind = 1040
|
||||
KindGiftWrap Kind = 1059
|
||||
KindFileMetadata Kind = 1063
|
||||
KindLiveChatMessage Kind = 1311
|
||||
KindPatch Kind = 1617
|
||||
KindIssue Kind = 1621
|
||||
KindReply Kind = 1622
|
||||
KindStatusOpen Kind = 1630
|
||||
KindStatusApplied Kind = 1631
|
||||
KindStatusClosed Kind = 1632
|
||||
KindStatusDraft Kind = 1633
|
||||
KindProblemTracker Kind = 1971
|
||||
KindReporting Kind = 1984
|
||||
KindLabel Kind = 1985
|
||||
KindRelayReviews Kind = 1986
|
||||
KindAIEmbeddings Kind = 1987
|
||||
KindTorrent Kind = 2003
|
||||
KindTorrentComment Kind = 2004
|
||||
KindCoinjoinPool Kind = 2022
|
||||
KindCommunityPostApproval Kind = 4550
|
||||
KindJobFeedback Kind = 7000
|
||||
KindSimpleGroupPutUser Kind = 9000
|
||||
KindSimpleGroupRemoveUser Kind = 9001
|
||||
KindSimpleGroupEditMetadata Kind = 9002
|
||||
KindSimpleGroupDeleteEvent Kind = 9005
|
||||
KindSimpleGroupCreateGroup Kind = 9007
|
||||
KindSimpleGroupDeleteGroup Kind = 9008
|
||||
KindSimpleGroupCreateInvite Kind = 9009
|
||||
KindSimpleGroupJoinRequest Kind = 9021
|
||||
KindSimpleGroupLeaveRequest Kind = 9022
|
||||
KindZapGoal Kind = 9041
|
||||
KindNutZap Kind = 9321
|
||||
KindTidalLogin Kind = 9467
|
||||
KindZapRequest Kind = 9734
|
||||
KindZap Kind = 9735
|
||||
KindHighlights Kind = 9802
|
||||
KindMuteList Kind = 10000
|
||||
KindPinList Kind = 10001
|
||||
KindRelayListMetadata Kind = 10002
|
||||
KindBookmarkList Kind = 10003
|
||||
KindCommunityList Kind = 10004
|
||||
KindPublicChatList Kind = 10005
|
||||
KindBlockedRelayList Kind = 10006
|
||||
KindSearchRelayList Kind = 10007
|
||||
KindSimpleGroupList Kind = 10009
|
||||
KindInterestList Kind = 10015
|
||||
KindNutZapInfo Kind = 10019
|
||||
KindEmojiList Kind = 10030
|
||||
KindDMRelayList Kind = 10050
|
||||
KindUserServerList Kind = 10063
|
||||
KindFileStorageServerList Kind = 10096
|
||||
KindGoodWikiAuthorList Kind = 10101
|
||||
KindGoodWikiRelayList Kind = 10102
|
||||
KindNWCWalletInfo Kind = 13194
|
||||
KindLightningPubRPC Kind = 21000
|
||||
KindClientAuthentication Kind = 22242
|
||||
KindNWCWalletRequest Kind = 23194
|
||||
KindNWCWalletResponse Kind = 23195
|
||||
KindNostrConnect Kind = 24133
|
||||
KindBlobs Kind = 24242
|
||||
KindHTTPAuth Kind = 27235
|
||||
KindCategorizedPeopleList Kind = 30000
|
||||
KindCategorizedBookmarksList Kind = 30001
|
||||
KindRelaySets Kind = 30002
|
||||
KindBookmarkSets Kind = 30003
|
||||
KindCuratedSets Kind = 30004
|
||||
KindCuratedVideoSets Kind = 30005
|
||||
KindMuteSets Kind = 30007
|
||||
KindProfileBadges Kind = 30008
|
||||
KindBadgeDefinition Kind = 30009
|
||||
KindInterestSets Kind = 30015
|
||||
KindStallDefinition Kind = 30017
|
||||
KindProductDefinition Kind = 30018
|
||||
KindMarketplaceUI Kind = 30019
|
||||
KindProductSoldAsAuction Kind = 30020
|
||||
KindArticle Kind = 30023
|
||||
KindDraftArticle Kind = 30024
|
||||
KindEmojiSets Kind = 30030
|
||||
KindModularArticleHeader Kind = 30040
|
||||
KindModularArticleContent Kind = 30041
|
||||
KindReleaseArtifactSets Kind = 30063
|
||||
KindApplicationSpecificData Kind = 30078
|
||||
KindLiveEvent Kind = 30311
|
||||
KindUserStatuses Kind = 30315
|
||||
KindClassifiedListing Kind = 30402
|
||||
KindDraftClassifiedListing Kind = 30403
|
||||
KindRepositoryAnnouncement Kind = 30617
|
||||
KindRepositoryState Kind = 30618
|
||||
KindSimpleGroupMetadata Kind = 39000
|
||||
KindSimpleGroupAdmins Kind = 39001
|
||||
KindSimpleGroupMembers Kind = 39002
|
||||
KindSimpleGroupRoles Kind = 39003
|
||||
KindWikiArticle Kind = 30818
|
||||
KindRedirects Kind = 30819
|
||||
KindFeed Kind = 31890
|
||||
KindDateCalendarEvent Kind = 31922
|
||||
KindTimeCalendarEvent Kind = 31923
|
||||
KindCalendar Kind = 31924
|
||||
KindCalendarEventRSVP Kind = 31925
|
||||
KindHandlerRecommendation Kind = 31989
|
||||
KindHandlerInformation Kind = 31990
|
||||
KindVideoEvent Kind = 34235
|
||||
KindShortVideoEvent Kind = 34236
|
||||
KindVideoViewEvent Kind = 34237
|
||||
KindCommunityDefinition Kind = 34550
|
||||
KindProfileMetadata Kind = 0
|
||||
KindTextNote Kind = 1
|
||||
KindRecommendServer Kind = 2
|
||||
KindFollowList Kind = 3
|
||||
KindEncryptedDirectMessage Kind = 4
|
||||
KindDeletion Kind = 5
|
||||
KindRepost Kind = 6
|
||||
KindReaction Kind = 7
|
||||
KindBadgeAward Kind = 8
|
||||
KindSimpleGroupChatMessage Kind = 9
|
||||
KindSimpleGroupThreadedReply Kind = 10
|
||||
KindSimpleGroupThread Kind = 11
|
||||
KindSimpleGroupReply Kind = 12
|
||||
KindSeal Kind = 13
|
||||
KindDirectMessage Kind = 14
|
||||
KindGenericRepost Kind = 16
|
||||
KindReactionToWebsite Kind = 17
|
||||
KindChannelCreation Kind = 40
|
||||
KindChannelMetadata Kind = 41
|
||||
KindChannelMessage Kind = 42
|
||||
KindChannelHideMessage Kind = 43
|
||||
KindChannelMuteUser Kind = 44
|
||||
KindChess Kind = 64
|
||||
KindMergeRequests Kind = 818
|
||||
KindComment Kind = 1111
|
||||
KindBid Kind = 1021
|
||||
KindBidConfirmation Kind = 1022
|
||||
KindOpenTimestamps Kind = 1040
|
||||
KindGiftWrap Kind = 1059
|
||||
KindFileMetadata Kind = 1063
|
||||
KindLiveChatMessage Kind = 1311
|
||||
KindPatch Kind = 1617
|
||||
KindIssue Kind = 1621
|
||||
KindReply Kind = 1622
|
||||
KindStatusOpen Kind = 1630
|
||||
KindStatusApplied Kind = 1631
|
||||
KindStatusClosed Kind = 1632
|
||||
KindStatusDraft Kind = 1633
|
||||
KindProblemTracker Kind = 1971
|
||||
KindReporting Kind = 1984
|
||||
KindLabel Kind = 1985
|
||||
KindRelayReviews Kind = 1986
|
||||
KindAIEmbeddings Kind = 1987
|
||||
KindTorrent Kind = 2003
|
||||
KindTorrentComment Kind = 2004
|
||||
KindCoinjoinPool Kind = 2022
|
||||
KindCommunityPostApproval Kind = 4550
|
||||
KindJobFeedback Kind = 7000
|
||||
KindSimpleGroupPutUser Kind = 9000
|
||||
KindSimpleGroupRemoveUser Kind = 9001
|
||||
KindSimpleGroupEditMetadata Kind = 9002
|
||||
KindSimpleGroupDeleteEvent Kind = 9005
|
||||
KindSimpleGroupCreateGroup Kind = 9007
|
||||
KindSimpleGroupDeleteGroup Kind = 9008
|
||||
KindSimpleGroupCreateInvite Kind = 9009
|
||||
KindSimpleGroupJoinRequest Kind = 9021
|
||||
KindSimpleGroupLeaveRequest Kind = 9022
|
||||
KindZapGoal Kind = 9041
|
||||
KindNutZap Kind = 9321
|
||||
KindTidalLogin Kind = 9467
|
||||
KindZapRequest Kind = 9734
|
||||
KindZap Kind = 9735
|
||||
KindHighlights Kind = 9802
|
||||
KindMuteList Kind = 10000
|
||||
KindPinList Kind = 10001
|
||||
KindRelayListMetadata Kind = 10002
|
||||
KindBookmarkList Kind = 10003
|
||||
KindCommunityList Kind = 10004
|
||||
KindPublicChatList Kind = 10005
|
||||
KindBlockedRelayList Kind = 10006
|
||||
KindSearchRelayList Kind = 10007
|
||||
KindSimpleGroupList Kind = 10009
|
||||
KindInterestList Kind = 10015
|
||||
KindNutZapInfo Kind = 10019
|
||||
KindEmojiList Kind = 10030
|
||||
KindDMRelayList Kind = 10050
|
||||
KindUserServerList Kind = 10063
|
||||
KindFileStorageServerList Kind = 10096
|
||||
KindGoodWikiAuthorList Kind = 10101
|
||||
KindGoodWikiRelayList Kind = 10102
|
||||
KindNWCWalletInfo Kind = 13194
|
||||
KindLightningPubRPC Kind = 21000
|
||||
KindClientAuthentication Kind = 22242
|
||||
KindNWCWalletRequest Kind = 23194
|
||||
KindNWCWalletResponse Kind = 23195
|
||||
KindNostrConnect Kind = 24133
|
||||
KindBlobs Kind = 24242
|
||||
KindHTTPAuth Kind = 27235
|
||||
KindCategorizedPeopleList Kind = 30000
|
||||
KindCategorizedBookmarksList Kind = 30001
|
||||
KindRelaySets Kind = 30002
|
||||
KindBookmarkSets Kind = 30003
|
||||
KindCuratedSets Kind = 30004
|
||||
KindCuratedVideoSets Kind = 30005
|
||||
KindMuteSets Kind = 30007
|
||||
KindProfileBadges Kind = 30008
|
||||
KindBadgeDefinition Kind = 30009
|
||||
KindInterestSets Kind = 30015
|
||||
KindStallDefinition Kind = 30017
|
||||
KindProductDefinition Kind = 30018
|
||||
KindMarketplaceUI Kind = 30019
|
||||
KindProductSoldAsAuction Kind = 30020
|
||||
KindArticle Kind = 30023
|
||||
KindDraftArticle Kind = 30024
|
||||
KindEmojiSets Kind = 30030
|
||||
KindModularArticleHeader Kind = 30040
|
||||
KindModularArticleContent Kind = 30041
|
||||
KindReleaseArtifactSets Kind = 30063
|
||||
KindApplicationSpecificData Kind = 30078
|
||||
KindLiveEvent Kind = 30311
|
||||
KindUserStatuses Kind = 30315
|
||||
KindClassifiedListing Kind = 30402
|
||||
KindDraftClassifiedListing Kind = 30403
|
||||
KindRepositoryAnnouncement Kind = 30617
|
||||
KindRepositoryState Kind = 30618
|
||||
KindSimpleGroupMetadata Kind = 39000
|
||||
KindSimpleGroupAdmins Kind = 39001
|
||||
KindSimpleGroupMembers Kind = 39002
|
||||
KindSimpleGroupRoles Kind = 39003
|
||||
KindSimpleGroupLiveKitParticipants Kind = 39004
|
||||
KindWikiArticle Kind = 30818
|
||||
KindRedirects Kind = 30819
|
||||
KindFeed Kind = 31890
|
||||
KindDateCalendarEvent Kind = 31922
|
||||
KindTimeCalendarEvent Kind = 31923
|
||||
KindCalendar Kind = 31924
|
||||
KindCalendarEventRSVP Kind = 31925
|
||||
KindHandlerRecommendation Kind = 31989
|
||||
KindHandlerInformation Kind = 31990
|
||||
KindVideoEvent Kind = 34235
|
||||
KindShortVideoEvent Kind = 34236
|
||||
KindVideoViewEvent Kind = 34237
|
||||
KindCommunityDefinition Kind = 34550
|
||||
)
|
||||
|
||||
func (kind Kind) IsRegular() bool {
|
||||
|
||||
+126
-35
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
@@ -38,10 +39,11 @@ func ParseGroupAddress(raw string) (GroupAddress, error) {
|
||||
type Group struct {
|
||||
Address GroupAddress
|
||||
|
||||
Name string
|
||||
Picture string
|
||||
About string
|
||||
Members map[nostr.PubKey][]*Role
|
||||
Name string
|
||||
Picture string
|
||||
About string
|
||||
Members map[nostr.PubKey][]*Role
|
||||
LiveKitParticipants []nostr.PubKey
|
||||
|
||||
// indicates that only members can read group messages
|
||||
Private bool
|
||||
@@ -55,13 +57,20 @@ type Group struct {
|
||||
// indicates that relays should hide group metadata from non-members
|
||||
Hidden bool
|
||||
|
||||
// indicates that the group supports audio/video live chat
|
||||
LiveKit bool
|
||||
|
||||
// indicates which event kinds this group supports
|
||||
SupportedKinds []nostr.Kind
|
||||
|
||||
Roles []*Role
|
||||
InviteCodes []string
|
||||
|
||||
LastMetadataUpdate nostr.Timestamp
|
||||
LastAdminsUpdate nostr.Timestamp
|
||||
LastMembersUpdate nostr.Timestamp
|
||||
LastRolesUpdate nostr.Timestamp
|
||||
LastMetadataUpdate nostr.Timestamp
|
||||
LastAdminsUpdate nostr.Timestamp
|
||||
LastMembersUpdate nostr.Timestamp
|
||||
LastRolesUpdate nostr.Timestamp
|
||||
LastLiveKitParticipantsUpdate nostr.Timestamp
|
||||
}
|
||||
|
||||
func (group Group) String() string {
|
||||
@@ -83,6 +92,11 @@ func (group Group) String() string {
|
||||
maybeClosed = " closed"
|
||||
}
|
||||
|
||||
maybeLiveKit := ""
|
||||
if group.LiveKit {
|
||||
maybeLiveKit = " livekit"
|
||||
}
|
||||
|
||||
members := make([]string, len(group.Members))
|
||||
i := 0
|
||||
for pubkey, roles := range group.Members {
|
||||
@@ -101,13 +115,14 @@ func (group Group) String() string {
|
||||
i++
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<Group %s name="%s"%s%s%s%s picture="%s" about="%s" members=[%v]>`,
|
||||
return fmt.Sprintf(`<Group %s name="%s"%s%s%s%s%s picture="%s" about="%s" members=[%v]>`,
|
||||
group.Address,
|
||||
group.Name,
|
||||
maybePrivate,
|
||||
maybeRestricted,
|
||||
maybeHidden,
|
||||
maybeClosed,
|
||||
maybeLiveKit,
|
||||
group.Picture,
|
||||
group.About,
|
||||
strings.Join(members, " "),
|
||||
@@ -122,9 +137,10 @@ func NewGroup(gadstr string) (Group, error) {
|
||||
}
|
||||
|
||||
return Group{
|
||||
Address: gad,
|
||||
Name: gad.ID,
|
||||
Members: make(map[nostr.PubKey][]*Role),
|
||||
Address: gad,
|
||||
Name: gad.ID,
|
||||
Members: make(map[nostr.PubKey][]*Role),
|
||||
LiveKitParticipants: make([]nostr.PubKey, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -134,8 +150,9 @@ func NewGroupFromMetadataEvent(relayURL string, evt *nostr.Event) (Group, error)
|
||||
Relay: relayURL,
|
||||
ID: evt.Tags.GetD(),
|
||||
},
|
||||
Name: evt.Tags.GetD(),
|
||||
Members: make(map[nostr.PubKey][]*Role),
|
||||
Name: evt.Tags.GetD(),
|
||||
Members: make(map[nostr.PubKey][]*Role),
|
||||
LiveKitParticipants: make([]nostr.PubKey, 0),
|
||||
}
|
||||
|
||||
err := g.MergeInMetadataEvent(evt)
|
||||
@@ -173,6 +190,18 @@ func (group Group) ToMetadataEvent() nostr.Event {
|
||||
if group.Closed {
|
||||
evt.Tags = append(evt.Tags, nostr.Tag{"closed"})
|
||||
}
|
||||
if group.LiveKit {
|
||||
evt.Tags = append(evt.Tags, nostr.Tag{"livekit"})
|
||||
}
|
||||
|
||||
if group.SupportedKinds != nil {
|
||||
tag := make(nostr.Tag, 1, 1+len(group.SupportedKinds))
|
||||
tag[0] = "supported_kinds"
|
||||
for _, kind := range group.SupportedKinds {
|
||||
tag = append(tag, strconv.Itoa(int(kind)))
|
||||
}
|
||||
evt.Tags = append(evt.Tags, tag)
|
||||
}
|
||||
|
||||
return evt
|
||||
}
|
||||
@@ -236,6 +265,22 @@ func (group Group) ToRolesEvent() nostr.Event {
|
||||
return evt
|
||||
}
|
||||
|
||||
func (group Group) ToLiveKitParticipantsEvent() nostr.Event {
|
||||
evt := nostr.Event{
|
||||
Kind: nostr.KindSimpleGroupLiveKitParticipants,
|
||||
CreatedAt: group.LastLiveKitParticipantsUpdate,
|
||||
Tags: make(nostr.Tags, 1, 1+len(group.LiveKitParticipants)),
|
||||
}
|
||||
evt.Tags[0] = nostr.Tag{"d", group.Address.ID}
|
||||
|
||||
for _, member := range group.LiveKitParticipants {
|
||||
tag := nostr.Tag{"participant", member.Hex()}
|
||||
evt.Tags = append(evt.Tags, tag)
|
||||
}
|
||||
|
||||
return evt
|
||||
}
|
||||
|
||||
func (group *Group) MergeInMetadataEvent(evt *nostr.Event) error {
|
||||
if evt.Kind != nostr.KindSimpleGroupMetadata {
|
||||
return fmt.Errorf("expected kind %d, got %d", nostr.KindSimpleGroupMetadata, evt.Kind)
|
||||
@@ -247,27 +292,42 @@ func (group *Group) MergeInMetadataEvent(evt *nostr.Event) error {
|
||||
group.LastMetadataUpdate = evt.CreatedAt
|
||||
group.Name = group.Address.ID
|
||||
|
||||
if tag := evt.Tags.Find("name"); tag != nil {
|
||||
group.Name = tag[1]
|
||||
}
|
||||
if tag := evt.Tags.Find("about"); tag != nil {
|
||||
group.About = tag[1]
|
||||
}
|
||||
if tag := evt.Tags.Find("picture"); tag != nil {
|
||||
group.Picture = tag[1]
|
||||
}
|
||||
|
||||
if tag := evt.Tags.Find("private"); tag != nil {
|
||||
group.Private = true
|
||||
}
|
||||
if tag := evt.Tags.Find("restricted"); tag != nil {
|
||||
group.Restricted = true
|
||||
}
|
||||
if tag := evt.Tags.Find("hidden"); tag != nil {
|
||||
group.Hidden = true
|
||||
}
|
||||
if tag := evt.Tags.Find("closed"); tag != nil {
|
||||
group.Closed = true
|
||||
for _, tag := range evt.Tags {
|
||||
if len(tag) >= 1 {
|
||||
switch tag[0] {
|
||||
case "private":
|
||||
group.Private = true
|
||||
case "restricted":
|
||||
group.Restricted = true
|
||||
case "closed":
|
||||
group.Closed = true
|
||||
case "hidden":
|
||||
group.Hidden = true
|
||||
case "livekit":
|
||||
group.LiveKit = true
|
||||
case "supported_kinds":
|
||||
kinds := make([]nostr.Kind, 0, len(tag)-1)
|
||||
for _, raw := range tag[1:] {
|
||||
kind, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
kinds = append(kinds, nostr.Kind(kind))
|
||||
}
|
||||
group.SupportedKinds = kinds
|
||||
default:
|
||||
if len(tag) >= 2 {
|
||||
switch tag[0] {
|
||||
case "name":
|
||||
group.Name = tag[1]
|
||||
case "about":
|
||||
group.About = tag[1]
|
||||
case "picture":
|
||||
group.Picture = tag[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -368,3 +428,34 @@ func (group *Group) MergeInRolesEvent(evt *nostr.Event) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (group *Group) MergeInLiveKitParticipantsEvent(evt *nostr.Event) error {
|
||||
if evt.Kind != nostr.KindSimpleGroupLiveKitParticipants {
|
||||
return fmt.Errorf("expected kind %d, got %d", nostr.KindSimpleGroupLiveKitParticipants, evt.Kind)
|
||||
}
|
||||
if evt.CreatedAt < group.LastLiveKitParticipantsUpdate {
|
||||
return fmt.Errorf("event is older than our last update (%d vs %d)", evt.CreatedAt, group.LastLiveKitParticipantsUpdate)
|
||||
}
|
||||
|
||||
group.LastLiveKitParticipantsUpdate = evt.CreatedAt
|
||||
group.LiveKitParticipants = make([]nostr.PubKey, 0, len(evt.Tags))
|
||||
for _, tag := range evt.Tags {
|
||||
if len(tag) < 2 {
|
||||
continue
|
||||
}
|
||||
if tag[0] != "participant" {
|
||||
continue
|
||||
}
|
||||
|
||||
member, err := nostr.PubKeyFromHex(tag[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(group.LiveKitParticipants, member) {
|
||||
continue
|
||||
}
|
||||
group.LiveKitParticipants = append(group.LiveKitParticipants, member)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+129
-48
@@ -3,6 +3,7 @@ package nip29
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
)
|
||||
@@ -78,48 +79,101 @@ var moderationActionFactories = map[nostr.Kind]func(nostr.Event) (Action, error)
|
||||
nostr.KindSimpleGroupEditMetadata: func(evt nostr.Event) (Action, error) {
|
||||
ok := false
|
||||
edit := EditMetadata{When: evt.CreatedAt}
|
||||
if t := evt.Tags.Find("name"); t != nil {
|
||||
edit.NameValue = &t[1]
|
||||
ok = true
|
||||
}
|
||||
if t := evt.Tags.Find("picture"); t != nil {
|
||||
edit.PictureValue = &t[1]
|
||||
ok = true
|
||||
}
|
||||
if t := evt.Tags.Find("about"); t != nil {
|
||||
edit.AboutValue = &t[1]
|
||||
ok = true
|
||||
}
|
||||
|
||||
y := true
|
||||
n := false
|
||||
if evt.Tags.Has("closed") {
|
||||
edit.ClosedValue = &y
|
||||
ok = true
|
||||
} else if evt.Tags.Has("open") {
|
||||
edit.ClosedValue = &n
|
||||
ok = true
|
||||
}
|
||||
if evt.Tags.Has("restricted") {
|
||||
edit.RestrictedValue = &y
|
||||
ok = true
|
||||
} else if evt.Tags.Has("unrestricted") {
|
||||
edit.RestrictedValue = &n
|
||||
ok = true
|
||||
}
|
||||
if evt.Tags.Has("hidden") {
|
||||
edit.HiddenValue = &y
|
||||
ok = true
|
||||
} else if evt.Tags.Has("visible") {
|
||||
edit.HiddenValue = &n
|
||||
ok = true
|
||||
}
|
||||
if evt.Tags.Has("private") {
|
||||
edit.PrivateValue = &y
|
||||
ok = true
|
||||
} else if evt.Tags.Has("public") {
|
||||
edit.PrivateValue = &n
|
||||
ok = true
|
||||
|
||||
hasName := false
|
||||
|
||||
// DEPRECATED: remove all the fields not tagged with Replace = true eventually
|
||||
// edit-metadata to become a PUT rather than a PATCH
|
||||
|
||||
for _, tag := range evt.Tags {
|
||||
if len(tag) >= 1 {
|
||||
switch tag[0] {
|
||||
case "name":
|
||||
if len(tag) >= 2 {
|
||||
edit.NameValue = &tag[1]
|
||||
if ok {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
hasName = true
|
||||
}
|
||||
case "picture":
|
||||
if len(tag) >= 2 {
|
||||
edit.PictureValue = &tag[1]
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
}
|
||||
case "about":
|
||||
if len(tag) >= 2 {
|
||||
edit.AboutValue = &tag[1]
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
}
|
||||
case "supported_kinds":
|
||||
kinds := make([]nostr.Kind, 0, len(tag)-1)
|
||||
for _, kstr := range tag[1:] {
|
||||
if kind, err := strconv.ParseUint(kstr, 10, 16); err != nil {
|
||||
return nil, fmt.Errorf("invalid kind: %w", err)
|
||||
} else {
|
||||
kinds = append(kinds, nostr.Kind(kind))
|
||||
}
|
||||
}
|
||||
edit.SupportedKindsValue = &kinds
|
||||
edit.Replace = true
|
||||
case "closed":
|
||||
edit.ClosedValue = &y
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
case "open":
|
||||
edit.ClosedValue = &n
|
||||
ok = true
|
||||
case "restricted":
|
||||
edit.RestrictedValue = &y
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
case "unrestricted":
|
||||
edit.RestrictedValue = &n
|
||||
ok = true
|
||||
case "hidden":
|
||||
edit.HiddenValue = &y
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
case "visible":
|
||||
edit.HiddenValue = &n
|
||||
ok = true
|
||||
case "private":
|
||||
edit.PrivateValue = &y
|
||||
if hasName {
|
||||
edit.Replace = true
|
||||
}
|
||||
ok = true
|
||||
case "public":
|
||||
edit.PrivateValue = &n
|
||||
ok = true
|
||||
case "livekit":
|
||||
edit.LiveKitValue = &y
|
||||
edit.Replace = true
|
||||
ok = true
|
||||
case "no-livekit":
|
||||
edit.LiveKitValue = &n
|
||||
ok = true
|
||||
case "no-text":
|
||||
edit.SupportedKindsValue = nil
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
@@ -226,19 +280,36 @@ func (a RemoveUser) Apply(group *Group) {
|
||||
}
|
||||
|
||||
type EditMetadata struct {
|
||||
NameValue *string
|
||||
PictureValue *string
|
||||
AboutValue *string
|
||||
RestrictedValue *bool
|
||||
ClosedValue *bool
|
||||
HiddenValue *bool
|
||||
PrivateValue *bool
|
||||
When nostr.Timestamp
|
||||
NameValue *string
|
||||
PictureValue *string
|
||||
AboutValue *string
|
||||
RestrictedValue *bool
|
||||
ClosedValue *bool
|
||||
HiddenValue *bool
|
||||
PrivateValue *bool
|
||||
LiveKitValue *bool
|
||||
SupportedKindsValue *[]nostr.Kind
|
||||
|
||||
Replace bool
|
||||
When nostr.Timestamp
|
||||
}
|
||||
|
||||
func (_ EditMetadata) Name() string { return "edit-metadata" }
|
||||
func (a EditMetadata) Apply(group *Group) {
|
||||
group.LastMetadataUpdate = a.When
|
||||
|
||||
if a.Replace {
|
||||
group.Name = ""
|
||||
group.Picture = ""
|
||||
group.About = ""
|
||||
group.Restricted = false
|
||||
group.Closed = false
|
||||
group.Hidden = false
|
||||
group.Private = false
|
||||
group.LiveKit = false
|
||||
group.SupportedKinds = nil
|
||||
}
|
||||
|
||||
if a.NameValue != nil {
|
||||
group.Name = *a.NameValue
|
||||
}
|
||||
@@ -260,6 +331,12 @@ func (a EditMetadata) Apply(group *Group) {
|
||||
if a.PrivateValue != nil {
|
||||
group.Private = *a.PrivateValue
|
||||
}
|
||||
if a.LiveKitValue != nil {
|
||||
group.LiveKit = *a.LiveKitValue
|
||||
}
|
||||
if a.SupportedKindsValue != nil {
|
||||
group.SupportedKinds = *a.SupportedKindsValue
|
||||
}
|
||||
}
|
||||
|
||||
type CreateGroup struct {
|
||||
@@ -272,6 +349,7 @@ func (a CreateGroup) Apply(group *Group) {
|
||||
group.LastMetadataUpdate = a.When
|
||||
group.LastAdminsUpdate = a.When
|
||||
group.LastMembersUpdate = a.When
|
||||
group.LastLiveKitParticipantsUpdate = a.When
|
||||
}
|
||||
|
||||
type DeleteGroup struct {
|
||||
@@ -281,6 +359,7 @@ type DeleteGroup struct {
|
||||
func (_ DeleteGroup) Name() string { return "delete-group" }
|
||||
func (a DeleteGroup) Apply(group *Group) {
|
||||
group.Members = make(map[nostr.PubKey][]*Role)
|
||||
group.LiveKitParticipants = make([]nostr.PubKey, 0)
|
||||
group.Closed = true
|
||||
group.Private = true
|
||||
group.Restricted = true
|
||||
@@ -288,9 +367,11 @@ func (a DeleteGroup) Apply(group *Group) {
|
||||
group.Name = "[deleted]"
|
||||
group.About = ""
|
||||
group.Picture = ""
|
||||
group.LiveKit = false
|
||||
group.LastMetadataUpdate = a.When
|
||||
group.LastAdminsUpdate = a.When
|
||||
group.LastMembersUpdate = a.When
|
||||
group.LastLiveKitParticipantsUpdate = a.When
|
||||
}
|
||||
|
||||
type CreateInvite struct {
|
||||
|
||||
@@ -28,6 +28,7 @@ var MetadataEventKinds = KindRange{
|
||||
nostr.KindSimpleGroupAdmins,
|
||||
nostr.KindSimpleGroupMembers,
|
||||
nostr.KindSimpleGroupRoles,
|
||||
nostr.KindSimpleGroupLiveKitParticipants,
|
||||
}
|
||||
|
||||
func (kr KindRange) Includes(kind nostr.Kind) bool {
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Email string
|
||||
Timestamp int64
|
||||
Timezone string
|
||||
}
|
||||
|
||||
type Commit struct {
|
||||
Hash string
|
||||
Tree string
|
||||
Parents []string
|
||||
Author Person
|
||||
Committer Person
|
||||
Message string
|
||||
}
|
||||
|
||||
func ParseCommit(data []byte, hash string) (*Commit, error) {
|
||||
content := string(data)
|
||||
|
||||
headerEndIndex := strings.Index(content, "\n\n")
|
||||
if headerEndIndex == -1 {
|
||||
return nil, fmt.Errorf("invalid commit format for %s: no message separator found", hash)
|
||||
}
|
||||
|
||||
header := content[:headerEndIndex]
|
||||
message := content[headerEndIndex+2:]
|
||||
|
||||
lines := strings.Split(header, "\n")
|
||||
result := &Commit{
|
||||
Hash: hash,
|
||||
Parents: []string{},
|
||||
Message: message,
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "tree ") {
|
||||
result.Tree = line[5:]
|
||||
} else if strings.HasPrefix(line, "parent ") {
|
||||
result.Parents = append(result.Parents, line[7:])
|
||||
} else if strings.HasPrefix(line, "author ") {
|
||||
person, err := parsePerson(line[7:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid author in commit %s: %w", hash, err)
|
||||
}
|
||||
result.Author = person
|
||||
} else if strings.HasPrefix(line, "committer ") {
|
||||
person, err := parsePerson(line[10:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid committer in commit %s: %w", hash, err)
|
||||
}
|
||||
result.Committer = person
|
||||
}
|
||||
}
|
||||
|
||||
if result.Tree == "" {
|
||||
return nil, fmt.Errorf("invalid commit format for %s: missing tree", hash)
|
||||
}
|
||||
if result.Author.Name == "" {
|
||||
return nil, fmt.Errorf("invalid commit format for %s: missing author", hash)
|
||||
}
|
||||
if result.Committer.Name == "" {
|
||||
return nil, fmt.Errorf("invalid commit format for %s: missing committer", hash)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parsePerson(line string) (Person, error) {
|
||||
emailStart := strings.Index(line, " <")
|
||||
if emailStart == -1 {
|
||||
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||
}
|
||||
name := line[:emailStart]
|
||||
|
||||
emailEnd := strings.Index(line[emailStart+2:], ">")
|
||||
if emailEnd == -1 {
|
||||
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||
}
|
||||
email := line[emailStart+2 : emailStart+2+emailEnd]
|
||||
|
||||
rest := strings.TrimSpace(line[emailStart+2+emailEnd+1:])
|
||||
parts := strings.SplitN(rest, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return Person{}, fmt.Errorf("invalid person format: %s", line)
|
||||
}
|
||||
|
||||
timestamp, err := strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
return Person{}, fmt.Errorf("invalid timestamp in person: %s", parts[0])
|
||||
}
|
||||
|
||||
return Person{
|
||||
Name: name,
|
||||
Email: email,
|
||||
Timestamp: timestamp,
|
||||
Timezone: parts[1],
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type DiffLine struct {
|
||||
Index int
|
||||
Status string
|
||||
Text string
|
||||
Change string
|
||||
}
|
||||
|
||||
type DiffFile struct {
|
||||
Path string
|
||||
Status string
|
||||
Content []byte
|
||||
Lines []DiffLine
|
||||
}
|
||||
|
||||
type changedEntry struct {
|
||||
newVersion TreeEntry
|
||||
oldVersions []TreeEntry
|
||||
}
|
||||
|
||||
func GetCommitDiff(url string, commitOrRef string) ([]DiffFile, error) {
|
||||
commit, err := GetSingleCommit(url, commitOrRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
added := make(map[string]TreeEntry)
|
||||
deleted := make(map[string]TreeEntry)
|
||||
changed := make(map[string]*changedEntry)
|
||||
unchanged := make(map[string]bool)
|
||||
|
||||
for _, parent := range commit.Parents {
|
||||
parentCommit, err := GetSingleCommit(url, parent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = computeTreeDiffs(url, commit.Tree, parentCommit.Tree, "", added, deleted, changed, unchanged)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var diff []DiffFile
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var firstErr error
|
||||
|
||||
for path, entry := range changed {
|
||||
p := path
|
||||
e := entry
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
curr, err := GetObject(url, e.newVersion.Hash)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
if curr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(e.oldVersions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
oldObj, err := GetObject(url, e.oldVersions[0].Hash)
|
||||
if err != nil || oldObj == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isBinary(curr.Data) || isBinary(oldObj.Data) {
|
||||
mu.Lock()
|
||||
diff = append(diff, DiffFile{
|
||||
Path: p,
|
||||
Status: "changed-binary",
|
||||
})
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
lines := diffTextLines(oldObj.Data, curr.Data)
|
||||
mu.Lock()
|
||||
diff = append(diff, DiffFile{
|
||||
Path: p,
|
||||
Status: "changed",
|
||||
Lines: lines,
|
||||
})
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
for path, entry := range deleted {
|
||||
p := path
|
||||
e := entry
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
obj, err := GetObject(url, e.Hash)
|
||||
if err != nil || obj == nil {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
diff = append(diff, DiffFile{
|
||||
Path: p,
|
||||
Status: "deleted",
|
||||
Content: obj.Data,
|
||||
})
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
for path, entry := range added {
|
||||
p := path
|
||||
e := entry
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
obj, err := GetObject(url, e.Hash)
|
||||
if err != nil || obj == nil {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
diff = append(diff, DiffFile{
|
||||
Path: p,
|
||||
Status: "added",
|
||||
Content: obj.Data,
|
||||
})
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if firstErr != nil {
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
return diff, nil
|
||||
}
|
||||
|
||||
func isBinary(data []byte) bool {
|
||||
for _, b := range data {
|
||||
if b == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func diffTextLines(oldData []byte, newData []byte) []DiffLine {
|
||||
oldText := string(oldData)
|
||||
newText := string(newData)
|
||||
oldLines := splitLines(oldText)
|
||||
newLines := splitLines(newText)
|
||||
|
||||
ops := lcsOperations(oldLines, newLines)
|
||||
allLines := make([]DiffLine, 0, len(ops))
|
||||
oldIndex := 1
|
||||
newIndex := 1
|
||||
|
||||
for i := 0; i < len(ops); i++ {
|
||||
op := ops[i]
|
||||
var next *lcsOp
|
||||
if i+1 < len(ops) {
|
||||
next = &ops[i+1]
|
||||
}
|
||||
|
||||
if op.typ == "del" && next != nil && next.typ == "add" {
|
||||
allLines = append(allLines, DiffLine{
|
||||
Status: "changed",
|
||||
Index: newIndex,
|
||||
Text: next.line,
|
||||
})
|
||||
oldIndex++
|
||||
newIndex++
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if op.typ == "add" && next != nil && next.typ == "del" {
|
||||
allLines = append(allLines, DiffLine{
|
||||
Status: "changed",
|
||||
Index: newIndex,
|
||||
Text: op.line,
|
||||
})
|
||||
oldIndex++
|
||||
newIndex++
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if op.typ == "add" {
|
||||
allLines = append(allLines, DiffLine{
|
||||
Status: "added",
|
||||
Index: newIndex,
|
||||
Text: op.line,
|
||||
})
|
||||
newIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
if op.typ == "del" {
|
||||
allLines = append(allLines, DiffLine{
|
||||
Status: "deleted",
|
||||
Index: oldIndex,
|
||||
Text: op.line,
|
||||
})
|
||||
oldIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
oldIndex++
|
||||
newIndex++
|
||||
}
|
||||
|
||||
if len(allLines) == 0 {
|
||||
return allLines
|
||||
}
|
||||
|
||||
keep := make([]bool, len(allLines))
|
||||
for i := 0; i < len(allLines); i++ {
|
||||
start := i - 3
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := i + 3
|
||||
if end >= len(allLines) {
|
||||
end = len(allLines) - 1
|
||||
}
|
||||
for j := start; j <= end; j++ {
|
||||
keep[j] = true
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]DiffLine, 0, len(allLines))
|
||||
for i := 0; i < len(allLines); i++ {
|
||||
if keep[i] {
|
||||
result = append(result, allLines[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type lcsOp struct {
|
||||
typ string
|
||||
line string
|
||||
}
|
||||
|
||||
func splitLines(text string) []string {
|
||||
lines := strings.Split(text, "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1]
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func lcsOperations(oldLines []string, newLines []string) []lcsOp {
|
||||
n := len(oldLines)
|
||||
m := len(newLines)
|
||||
|
||||
dp := make([][]uint32, n+1)
|
||||
for i := range dp {
|
||||
dp[i] = make([]uint32, m+1)
|
||||
}
|
||||
|
||||
for i := 1; i <= n; i++ {
|
||||
for j := 1; j <= m; j++ {
|
||||
if oldLines[i-1] == newLines[j-1] {
|
||||
dp[i][j] = dp[i-1][j-1] + 1
|
||||
} else if dp[i-1][j] >= dp[i][j-1] {
|
||||
dp[i][j] = dp[i-1][j]
|
||||
} else {
|
||||
dp[i][j] = dp[i][j-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ops := make([]lcsOp, 0, n+m)
|
||||
i := n
|
||||
j := m
|
||||
for i > 0 || j > 0 {
|
||||
if i > 0 && j > 0 && oldLines[i-1] == newLines[j-1] {
|
||||
ops = append(ops, lcsOp{typ: "equal", line: oldLines[i-1]})
|
||||
i--
|
||||
j--
|
||||
continue
|
||||
}
|
||||
|
||||
if i > 0 && (j == 0 || dp[i-1][j] >= dp[i][j-1]) {
|
||||
ops = append(ops, lcsOp{typ: "del", line: oldLines[i-1]})
|
||||
i--
|
||||
continue
|
||||
}
|
||||
|
||||
if j > 0 {
|
||||
ops = append(ops, lcsOp{typ: "add", line: newLines[j-1]})
|
||||
j--
|
||||
}
|
||||
}
|
||||
|
||||
for i, j := 0, len(ops)-1; i < j; i, j = i+1, j-1 {
|
||||
ops[i], ops[j] = ops[j], ops[i]
|
||||
}
|
||||
|
||||
return ops
|
||||
}
|
||||
|
||||
func computeTreeDiffs(
|
||||
url string,
|
||||
treeHash string,
|
||||
parentTreeHash string,
|
||||
basePath string,
|
||||
added map[string]TreeEntry,
|
||||
deleted map[string]TreeEntry,
|
||||
changed map[string]*changedEntry,
|
||||
unchanged map[string]bool,
|
||||
) error {
|
||||
var newTree []TreeEntry
|
||||
var oldTree []TreeEntry
|
||||
|
||||
if treeHash != "" {
|
||||
obj, err := GetObject(url, treeHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if obj != nil {
|
||||
newTree = ParseTree(obj.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if parentTreeHash != "" {
|
||||
obj, err := GetObject(url, parentTreeHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if obj != nil {
|
||||
oldTree = ParseTree(obj.Data)
|
||||
}
|
||||
}
|
||||
|
||||
for _, entry := range newTree {
|
||||
var old *TreeEntry
|
||||
for _, o := range oldTree {
|
||||
if o.Path == entry.Path {
|
||||
o := o
|
||||
old = &o
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if old != nil {
|
||||
delete(added, basePath+entry.Path)
|
||||
|
||||
if old.Hash == entry.Hash {
|
||||
unchanged[basePath+entry.Path] = true
|
||||
} else {
|
||||
if entry.IsDir {
|
||||
err := computeTreeDiffs(url, entry.Hash, old.Hash, basePath+entry.Path+"/", added, deleted, changed, unchanged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if existing, exists := changed[basePath+entry.Path]; !exists {
|
||||
changed[basePath+entry.Path] = &changedEntry{
|
||||
newVersion: entry,
|
||||
oldVersions: []TreeEntry{*old},
|
||||
}
|
||||
} else {
|
||||
existing.oldVersions = append(existing.oldVersions, *old)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if entry.IsDir {
|
||||
err := computeTreeDiffs(url, entry.Hash, "", basePath+entry.Path+"/", added, deleted, changed, unchanged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
added[basePath+entry.Path] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, old := range oldTree {
|
||||
if unchanged[basePath+old.Path] || changed[basePath+old.Path] != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if old.IsDir {
|
||||
err := computeTreeDiffs(url, "", old.Hash, basePath+old.Path+"/", added, deleted, changed, unchanged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
deleted[basePath+old.Path] = old
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MissingCapability struct {
|
||||
URL string
|
||||
Capability string
|
||||
}
|
||||
|
||||
func (e *MissingCapability) Error() string {
|
||||
return fmt.Sprintf("server at %s is missing required capability %s", e.URL, e.Capability)
|
||||
}
|
||||
|
||||
func prepareRequest(url string, commitOrRef string, needFilter bool) (resolvedRef string, capabilities []string, err error) {
|
||||
var info *InfoRefsUploadPackResponse
|
||||
if strings.HasPrefix(commitOrRef, "refs/") {
|
||||
info, err = GetInfoRefs(url)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
resolved, ok := info.Refs[commitOrRef]
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("ref %s not found", commitOrRef)
|
||||
}
|
||||
commitOrRef = resolved
|
||||
}
|
||||
|
||||
caps, err := GetCapabilities(url, info)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
for _, c := range DefaultCapabilities {
|
||||
if slices.Contains(caps, c) {
|
||||
capabilities = append(capabilities, c)
|
||||
}
|
||||
}
|
||||
for _, c := range NecessaryCapabilities {
|
||||
if slices.Contains(caps, c) {
|
||||
capabilities = append(capabilities, c)
|
||||
} else {
|
||||
return "", nil, &MissingCapability{URL: url, Capability: c}
|
||||
}
|
||||
}
|
||||
for _, c := range RequiredCapabilities {
|
||||
if !slices.Contains(caps, c) {
|
||||
return "", nil, &MissingCapability{URL: url, Capability: c}
|
||||
}
|
||||
}
|
||||
if needFilter {
|
||||
if slices.Contains(caps, "filter") {
|
||||
capabilities = append(capabilities, "filter")
|
||||
} else {
|
||||
return "", nil, &MissingCapability{URL: url, Capability: "filter"}
|
||||
}
|
||||
}
|
||||
|
||||
return commitOrRef, capabilities, nil
|
||||
}
|
||||
|
||||
func GetObject(url string, blobHash string) (*ParsedObject, error) {
|
||||
ref, caps, err := prepareRequest(url, blobHash, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deepen := 1
|
||||
want, err := CreateWantRequest(ref, caps, &deepen, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := FetchPackfile(url, want)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.Objects[blobHash], nil
|
||||
}
|
||||
|
||||
func GetDirectoryTreeAt(url string, commitOrRef string, nestLimit *int) (*Tree, error) {
|
||||
ref, caps, err := prepareRequest(url, commitOrRef, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
want, err := CreateWantRequest(ref, caps, nestLimit, "blob:none")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := FetchPackfile(url, want)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commit := result.Objects[ref]
|
||||
if commit == nil {
|
||||
return nil, fmt.Errorf("commit %s not found in packfile", ref)
|
||||
}
|
||||
|
||||
treeHash := string(commit.Data[5:45])
|
||||
rootTree := result.Objects[treeHash]
|
||||
if rootTree == nil {
|
||||
return nil, fmt.Errorf("root tree %s not found in packfile", treeHash)
|
||||
}
|
||||
|
||||
return LoadTree(rootTree, result.Objects, nestLimit), nil
|
||||
}
|
||||
|
||||
func ShallowCloneRepositoryAt(url string, commitOrRef string) (*Commit, *Tree, error) {
|
||||
ref, caps, err := prepareRequest(url, commitOrRef, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
deepen := 1
|
||||
want, err := CreateWantRequest(ref, caps, &deepen, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
result, err := FetchPackfile(url, want)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
commitObj := result.Objects[ref]
|
||||
if commitObj == nil {
|
||||
return nil, nil, fmt.Errorf("commit %s not found in packfile", ref)
|
||||
}
|
||||
|
||||
treeHash := string(commitObj.Data[5:45])
|
||||
rootTree := result.Objects[treeHash]
|
||||
if rootTree == nil {
|
||||
return nil, nil, fmt.Errorf("root tree %s not found in packfile", treeHash)
|
||||
}
|
||||
|
||||
commit, err := ParseCommit(commitObj.Data, commitObj.Hash)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
tree := LoadTree(rootTree, result.Objects, nil)
|
||||
return commit, tree, nil
|
||||
}
|
||||
|
||||
func FetchCommitsOnly(url string, commitOrRef string, maxCommits *int) ([]*Commit, error) {
|
||||
ref, caps, err := prepareRequest(url, commitOrRef, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
want, err := CreateWantRequest(ref, caps, maxCommits, "tree:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := FetchPackfile(url, want)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commitMap := make(map[string]*Commit, len(result.Objects))
|
||||
for hash, obj := range result.Objects {
|
||||
commit, err := ParseCommit(obj.Data, hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
commitMap[hash] = commit
|
||||
}
|
||||
|
||||
// sort topologically starting from the requested ref
|
||||
sorted := make([]*Commit, 0, len(commitMap))
|
||||
visited := make(map[string]bool, len(commitMap))
|
||||
var visit func(hash string)
|
||||
visit = func(hash string) {
|
||||
if visited[hash] {
|
||||
return
|
||||
}
|
||||
c, ok := commitMap[hash]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
visited[hash] = true
|
||||
sorted = append(sorted, c)
|
||||
for _, parent := range c.Parents {
|
||||
visit(parent)
|
||||
}
|
||||
}
|
||||
visit(ref)
|
||||
|
||||
for _, c := range commitMap {
|
||||
if !visited[c.Hash] {
|
||||
sorted = append(sorted, c)
|
||||
}
|
||||
}
|
||||
|
||||
return sorted, nil
|
||||
}
|
||||
|
||||
func GetSingleCommit(url string, commitOrRef string) (*Commit, error) {
|
||||
maxCommits := 1
|
||||
commits, err := FetchCommitsOnly(url, commitOrRef, &maxCommits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(commits) == 0 {
|
||||
return nil, fmt.Errorf("no commit found for reference: %s", commitOrRef)
|
||||
}
|
||||
return commits[0], nil
|
||||
}
|
||||
|
||||
func GetObjectByPath(url string, commitOrRef string, path string) (*TreeEntry, error) {
|
||||
normalizedPath := strings.ReplaceAll(path, "\\", "/")
|
||||
normalizedPath = strings.TrimLeft(normalizedPath, "/")
|
||||
normalizedPath = strings.TrimRight(normalizedPath, "/")
|
||||
|
||||
var pathSegments []string
|
||||
if normalizedPath != "" {
|
||||
pathSegments = strings.Split(normalizedPath, "/")
|
||||
}
|
||||
requiredDepth := len(pathSegments)
|
||||
|
||||
tree, err := GetDirectoryTreeAt(url, commitOrRef, &requiredDepth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentLevel := tree
|
||||
nextSegment:
|
||||
for i, segment := range pathSegments {
|
||||
isLastSegment := i == len(pathSegments)-1
|
||||
|
||||
for _, dir := range currentLevel.Directories {
|
||||
if dir.Name == segment {
|
||||
if isLastSegment {
|
||||
return &TreeEntry{Path: segment, Mode: "40000", IsDir: true, Hash: dir.Hash}, nil
|
||||
}
|
||||
if dir.Content != nil {
|
||||
currentLevel = dir.Content
|
||||
continue nextSegment
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if isLastSegment {
|
||||
for _, file := range currentLevel.Files {
|
||||
if file.Name == segment {
|
||||
return &TreeEntry{Path: segment, Mode: "100644", IsDir: false, Hash: file.Hash}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetRefs(t *testing.T) {
|
||||
info, err := GetInfoRefs("https://codeberg.org/dluvian/gitplaza.git")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, info.Capabilities, "shallow")
|
||||
require.Contains(t, info.Capabilities, "object-format=sha1")
|
||||
require.Greater(t, len(info.Refs), 5)
|
||||
require.Contains(t, info.Refs, "refs/heads/master")
|
||||
require.Equal(t, "a04d0761564b0d23c5edbadf494ab4f1cc4656f4", info.Refs["refs/tags/v0.1.0"])
|
||||
require.Equal(t, "refs/heads/master", info.Symrefs["HEAD"])
|
||||
}
|
||||
|
||||
func TestGetOnlyTreeAtCurrentCommit(t *testing.T) {
|
||||
urls := []string{
|
||||
"https://codeberg.org/dluvian/gitplaza.git",
|
||||
"https://github.com/fiatjaf/pyramid.git",
|
||||
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||
}
|
||||
|
||||
for _, url := range urls {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
tree, err := GetDirectoryTreeAt(url, "refs/heads/master", nil)
|
||||
require.NoError(t, err)
|
||||
for _, file := range tree.Files {
|
||||
require.Nil(t, file.Content, "file %q should have nil content at %s", file.Name, url)
|
||||
}
|
||||
require.Greater(t, len(tree.Directories), 2, "at %s", url)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneRepositoryAtCurrentCommit(t *testing.T) {
|
||||
urls := []string{
|
||||
"https://codeberg.org/dluvian/gitplaza.git",
|
||||
"https://github.com/fiatjaf/pyramid.git",
|
||||
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||
}
|
||||
|
||||
for _, url := range urls {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
commit, tree, err := ShallowCloneRepositoryAt(url, "refs/heads/master")
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(tree.Files), 5, "at %s", url)
|
||||
require.Greater(t, len(tree.Directories), 2, "at %s", url)
|
||||
|
||||
info, err := GetInfoRefs(url)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Refs["refs/heads/master"], commit.Hash, "at %s", url)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSpecificObject(t *testing.T) {
|
||||
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||
hash := "0f9438a8fd68594cd663fb8dbd23c5f5139f5263" // shell.nix
|
||||
|
||||
blob, err := GetObject(url, hash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, blob)
|
||||
require.Equal(t, ObjectTypeBlob, blob.Type)
|
||||
|
||||
expected := "(builtins.getFlake\n (\"git+file://\" + toString ./.)).devShells.${builtins.currentSystem}.default\n"
|
||||
require.Equal(t, expected, string(blob.Data))
|
||||
}
|
||||
|
||||
func TestGetNonExistentCommit(t *testing.T) {
|
||||
urls := []string{
|
||||
"https://codeberg.org/dluvian/gitplaza.git",
|
||||
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||
"https://github.com/fiatjaf/nak.git",
|
||||
}
|
||||
|
||||
commit := "1d4438a8fd68594cd663fb8dbd23c5f5139fabcd" // doesn't exist
|
||||
|
||||
for _, url := range urls {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
_, err := GetDirectoryTreeAt(url, commit, nil)
|
||||
require.Error(t, err)
|
||||
var missingRef *MissingRef
|
||||
require.ErrorAs(t, err, &missingRef)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchListOfCommits(t *testing.T) {
|
||||
commits, err := FetchCommitsOnly(
|
||||
"https://pyramid.fiatjaf.com/npub180cvv07tjdrrgpa0j7j7tmnyl2yr6yr7l8j4s3evf6u64th6gkwsyjh6w6/nostrlib.git",
|
||||
"refs/heads/master",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(commits), 10)
|
||||
}
|
||||
|
||||
func TestFetch10PastCommits(t *testing.T) {
|
||||
maxCommits := 10
|
||||
commits, err := FetchCommitsOnly(
|
||||
"https://github.com/fiatjaf/pyramid.git",
|
||||
"57712756e37d7c60d1ac53e0f6b59e9ecad67c9a",
|
||||
&maxCommits,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, commits, 10)
|
||||
|
||||
c := commits[1]
|
||||
require.Equal(t, "49c1b48f5120bad4089535a190d2233c96188fa2", c.Hash)
|
||||
require.Equal(t, "286786a6f1072a2ef5ae057fbb611858b8e88bc4", c.Tree)
|
||||
require.Equal(t, []string{"1599e46c0ee6f460e25048880754868d4f9644fd"}, c.Parents)
|
||||
require.Equal(t, "fiatjaf", c.Author.Name)
|
||||
require.Equal(t, "fiatjaf@gmail.com", c.Author.Email)
|
||||
require.Equal(t, int64(1767157644), c.Author.Timestamp)
|
||||
require.Equal(t, "-0300", c.Author.Timezone)
|
||||
require.Equal(t, "fiatjaf", c.Committer.Name)
|
||||
require.Equal(t, "fiatjaf@gmail.com", c.Committer.Email)
|
||||
require.Equal(t, int64(1767157724), c.Committer.Timestamp)
|
||||
require.Equal(t, "-0300", c.Committer.Timezone)
|
||||
require.Equal(t, "scheduled events.\n", c.Message)
|
||||
|
||||
expectedMsg5 := "turn off groups logic on QueryStore and PreventBroadcast when groups is turned off.\n\nthis was causing crashes that Golang's bizarre iter API showed as happening inside SortedMerge.\n"
|
||||
require.Equal(t, expectedMsg5, commits[5].Message)
|
||||
}
|
||||
|
||||
func TestGetSingleCommit(t *testing.T) {
|
||||
url := "https://github.com/fiatjaf/pyramid.git"
|
||||
commit, err := GetSingleCommit(url, "5e982dd1122a0bb1b0154c222ec4ba841f3820c6")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "5e982dd1122a0bb1b0154c222ec4ba841f3820c6", commit.Hash)
|
||||
require.Equal(t, "fiatjaf", commit.Author.Name)
|
||||
require.Equal(t, "validate incoming git-related stuff.\n", commit.Message)
|
||||
}
|
||||
|
||||
func TestGetDirectoryTreeWithDepthLimit(t *testing.T) {
|
||||
url := "https://github.com/fiatjaf/pyramid.git"
|
||||
|
||||
fullTree, err := GetDirectoryTreeAt(url, "refs/heads/master", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
depth := 1
|
||||
shallowTree, err := GetDirectoryTreeAt(url, "refs/heads/master", &depth)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, len(fullTree.Directories), len(shallowTree.Directories))
|
||||
|
||||
for _, dir := range shallowTree.Directories {
|
||||
require.NotNil(t, dir.Content, "directory %q content should not be nil at depth 1", dir.Name)
|
||||
for _, file := range dir.Content.Files {
|
||||
require.NotEmpty(t, file.Name)
|
||||
require.Nil(t, file.Content, "file %q content should be nil", file.Name)
|
||||
}
|
||||
for _, subdir := range dir.Content.Directories {
|
||||
require.NotEmpty(t, subdir.Name)
|
||||
require.Nil(t, subdir.Content, "subdir %q content should be nil at depth 1", subdir.Name)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, len(fullTree.Files), len(shallowTree.Files))
|
||||
}
|
||||
|
||||
func TestGetObjectByPathExistingFile(t *testing.T) {
|
||||
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||
entry, err := GetObjectByPath(url, "refs/heads/master", "README.md")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entry)
|
||||
require.Equal(t, "README.md", entry.Path)
|
||||
require.False(t, entry.IsDir)
|
||||
require.Equal(t, "100644", entry.Mode)
|
||||
require.NotEmpty(t, entry.Hash)
|
||||
}
|
||||
|
||||
func TestGetObjectByPathExistingDirectory(t *testing.T) {
|
||||
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||
entry, err := GetObjectByPath(url, "refs/heads/master", "src")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entry)
|
||||
require.Equal(t, "src", entry.Path)
|
||||
require.True(t, entry.IsDir)
|
||||
require.Equal(t, "40000", entry.Mode)
|
||||
require.NotEmpty(t, entry.Hash)
|
||||
}
|
||||
|
||||
func TestGetObjectByPathNestedFile(t *testing.T) {
|
||||
url := "https://github.com/fiatjaf/pyramid.git"
|
||||
entry, err := GetObjectByPath(url, "d567c18cd5c144a58b0214216f454b3caf49d4ff", "grasp/grasp.templ")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entry)
|
||||
require.Equal(t, "grasp.templ", entry.Path)
|
||||
require.False(t, entry.IsDir)
|
||||
require.Equal(t, "100644", entry.Mode)
|
||||
require.Equal(t, "05bce14339ece5f48c670d0592faa8dece9e8957", entry.Hash)
|
||||
}
|
||||
|
||||
func TestGetObjectByPathNonExistent(t *testing.T) {
|
||||
url := "https://codeberg.org/dluvian/gitplaza.git"
|
||||
entry, err := GetObjectByPath(url, "refs/heads/master", "whatever/something/x/y/z/non-existent-file.txt")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entry)
|
||||
}
|
||||
|
||||
func TestGetCommitDiff(t *testing.T) {
|
||||
url := "https://github.com/smallhelm/diff-lines.git"
|
||||
diff, err := GetCommitDiff(url, "a73592653fe9d01f948ca3035e088e45f722eca7")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, diff)
|
||||
require.Len(t, diff, 5)
|
||||
|
||||
byPath := make(map[string]DiffFile, len(diff))
|
||||
for _, f := range diff {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
path string
|
||||
status string
|
||||
}{
|
||||
{".travis.yml", "added"},
|
||||
{".gitignore", "added"},
|
||||
{"index.js", "added"},
|
||||
{"tests.js", "added"},
|
||||
{"package.json", "changed"},
|
||||
} {
|
||||
f, ok := byPath[tc.path]
|
||||
require.True(t, ok, "missing diff file %q", tc.path)
|
||||
require.Equal(t, tc.status, f.Status, "%s status", tc.path)
|
||||
}
|
||||
|
||||
gitignore, ok := byPath[".gitignore"]
|
||||
require.True(t, ok, "missing .gitignore in diff")
|
||||
require.Equal(t, "/node_modules\n", string(gitignore.Content))
|
||||
|
||||
pkg, ok := byPath["package.json"]
|
||||
require.True(t, ok, "missing package.json in diff")
|
||||
require.NotEmpty(t, pkg.Lines, "package.json should have diff lines")
|
||||
|
||||
normalizeLineStatus := func(status string) string {
|
||||
if status == "same" {
|
||||
return "changed"
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
lineByIndex := make(map[int]DiffLine, len(pkg.Lines))
|
||||
for _, line := range pkg.Lines {
|
||||
lineByIndex[line.Index] = line
|
||||
}
|
||||
|
||||
expectedLines := []struct {
|
||||
Index int
|
||||
Status string
|
||||
Text string
|
||||
}{
|
||||
{Index: 22, Status: "added", Text: " },"},
|
||||
{Index: 23, Status: "added", Text: " \"homepage\": \"https://github.com/smallhelm/diff-lines#readme\","},
|
||||
{Index: 24, Status: "added", Text: " \"devDependencies\": {"},
|
||||
{Index: 25, Status: "added", Text: " \"tape\": \"^4.6.0\""},
|
||||
{Index: 27, Status: "added", Text: " \"dependencies\": {"},
|
||||
{Index: 28, Status: "added", Text: " \"diff\": \"^2.2.3\""},
|
||||
{Index: 29, Status: "changed", Text: " }"},
|
||||
}
|
||||
|
||||
actualLines := make([]struct {
|
||||
Index int
|
||||
Status string
|
||||
Text string
|
||||
}, 0, len(expectedLines))
|
||||
for _, expected := range expectedLines {
|
||||
line, ok := lineByIndex[expected.Index]
|
||||
require.True(t, ok, "missing package.json diff line %d", expected.Index)
|
||||
actualLines = append(actualLines, struct {
|
||||
Index int
|
||||
Status string
|
||||
Text string
|
||||
}{
|
||||
Index: line.Index,
|
||||
Status: normalizeLineStatus(line.Status),
|
||||
Text: line.Text,
|
||||
})
|
||||
}
|
||||
|
||||
require.Equal(t, expectedLines, actualLines)
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var NecessaryCapabilities = []string{
|
||||
"multi_ack_detailed",
|
||||
"side-band-64k",
|
||||
}
|
||||
|
||||
var RequiredCapabilities = []string{
|
||||
"shallow",
|
||||
"object-format=sha1",
|
||||
}
|
||||
|
||||
var DefaultCapabilities = []string{
|
||||
"ofs-delta",
|
||||
"no-progress",
|
||||
}
|
||||
|
||||
type MissingRef struct{}
|
||||
|
||||
func (e *MissingRef) Error() string { return "missing ref" }
|
||||
|
||||
type InvalidCommit struct {
|
||||
Commit string
|
||||
}
|
||||
|
||||
func (e *InvalidCommit) Error() string {
|
||||
return fmt.Sprintf("invalid commit '%s', must be 20 byte hex", e.Commit)
|
||||
}
|
||||
|
||||
func FetchPackfile(url string, want string) (*PackfileResult, error) {
|
||||
req, err := http.NewRequest("POST", url+"/git-upload-pack", strings.NewReader(want))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create git-upload-pack request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-git-upload-pack-request")
|
||||
req.Header.Set("Accept", "application/x-git-upload-pack-result")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call git-upload-pack: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to call git-upload-pack: %s", string(body))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read git-upload-pack response: %w", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("empty response")
|
||||
}
|
||||
|
||||
offset := 0
|
||||
for offset < len(data) {
|
||||
prev := offset
|
||||
if prev+1 >= len(data) {
|
||||
break
|
||||
}
|
||||
nlIdx := bytes.IndexByte(data[prev+1:], '\n')
|
||||
if nlIdx == -1 {
|
||||
if len(data) >= 32 && string(data[4:32]) == "ERR upload-pack: not our ref" {
|
||||
return nil, &MissingRef{}
|
||||
}
|
||||
end := len(data)
|
||||
if end > 63 {
|
||||
end = 63
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected '%s'", string(data[:end]))
|
||||
}
|
||||
offset = prev + nlIdx + 1
|
||||
if offset >= 3 && string(data[offset-3:offset]) == "NAK" {
|
||||
break
|
||||
}
|
||||
}
|
||||
offset++
|
||||
|
||||
var packfileData []byte
|
||||
for offset < len(data) {
|
||||
if offset+5 > len(data) {
|
||||
break
|
||||
}
|
||||
pktLen, err := strconv.ParseInt(string(data[offset:offset+4]), 16, 32)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
length := int(pktLen)
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
if offset+length > len(data) {
|
||||
break
|
||||
}
|
||||
if data[offset+4] == 2 {
|
||||
// progress message, ignore
|
||||
} else if data[offset+4] == 1 {
|
||||
packfileData = append(packfileData, data[offset+5:offset+length]...)
|
||||
}
|
||||
offset += length
|
||||
}
|
||||
|
||||
if len(packfileData) == 0 {
|
||||
return nil, &MissingRef{}
|
||||
}
|
||||
|
||||
return ParsePackfile(packfileData)
|
||||
}
|
||||
|
||||
func CreateWantRequest(commitSha string, capabilities []string, deepen *int, filter string) (string, error) {
|
||||
if len(commitSha) != 40 {
|
||||
return "", &InvalidCommit{Commit: commitSha}
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
|
||||
wantLine := fmt.Sprintf("want %s %s agent=nsa/1.0.0\n", commitSha, strings.Join(capabilities, " "))
|
||||
buf.WriteString(pktEncode(wantLine))
|
||||
|
||||
if deepen != nil {
|
||||
deepenLine := fmt.Sprintf("deepen %d\n", *deepen)
|
||||
buf.WriteString(pktEncode(deepenLine))
|
||||
}
|
||||
|
||||
if filter != "" {
|
||||
filterLine := fmt.Sprintf("filter %s\n", filter)
|
||||
buf.WriteString(pktEncode(filterLine))
|
||||
}
|
||||
|
||||
buf.WriteString("0000")
|
||||
buf.WriteString(pktEncode("done\n"))
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func pktEncode(data string) string {
|
||||
if len(data) == 0 {
|
||||
return "0000"
|
||||
}
|
||||
length := len(data) + 4
|
||||
return fmt.Sprintf("%04x%s", length, data)
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
ObjectTypeCommit = 1
|
||||
ObjectTypeTree = 2
|
||||
ObjectTypeBlob = 3
|
||||
ObjectTypeTag = 4
|
||||
ObjectTypeOfsDelta = 6
|
||||
ObjectTypeRefDelta = 7
|
||||
)
|
||||
|
||||
type ParsedObject struct {
|
||||
Type int
|
||||
Size int
|
||||
Data []byte
|
||||
Offset int
|
||||
Hash string
|
||||
}
|
||||
|
||||
type PackfileResult struct {
|
||||
Version int
|
||||
Count int
|
||||
Objects map[string]*ParsedObject
|
||||
}
|
||||
|
||||
func ParsePackfile(data []byte) (*PackfileResult, error) {
|
||||
if len(data) < 12 {
|
||||
return nil, fmt.Errorf("packfile too short")
|
||||
}
|
||||
|
||||
header := string(data[0:4])
|
||||
if header != "PACK" {
|
||||
return nil, fmt.Errorf("invalid packfile header: %s", header)
|
||||
}
|
||||
|
||||
version := int(binary.BigEndian.Uint32(data[4:8]))
|
||||
if version != 2 {
|
||||
return nil, fmt.Errorf("unsupported packfile version: %d", version)
|
||||
}
|
||||
|
||||
count := int(binary.BigEndian.Uint32(data[8:12]))
|
||||
|
||||
objects := make(map[string]*ParsedObject)
|
||||
pos := 12
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
obj, newPos, err := parsePackObject(data, pos, objects)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing object %d/%d: %w", i+1, count, err)
|
||||
}
|
||||
objects[obj.Hash] = obj
|
||||
pos = newPos
|
||||
}
|
||||
|
||||
return &PackfileResult{
|
||||
Version: version,
|
||||
Count: count,
|
||||
Objects: objects,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePackObject(data []byte, startPos int, objects map[string]*ParsedObject) (*ParsedObject, int, error) {
|
||||
pos := startPos
|
||||
offset := startPos
|
||||
|
||||
b := data[pos]
|
||||
pos++
|
||||
objType := int((b >> 4) & 0x07)
|
||||
size := int(b & 0x0f)
|
||||
shift := 4
|
||||
|
||||
for b&0x80 != 0 {
|
||||
b = data[pos]
|
||||
pos++
|
||||
size |= int(b&0x7f) << shift
|
||||
shift += 7
|
||||
}
|
||||
|
||||
var objData []byte
|
||||
var err error
|
||||
|
||||
switch objType {
|
||||
case ObjectTypeOfsDelta:
|
||||
var actualType int
|
||||
objData, pos, actualType, err = parseOfsDelta(data, pos, offset, objects)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
objType = actualType
|
||||
case ObjectTypeRefDelta:
|
||||
var actualType int
|
||||
objData, pos, actualType, err = parseRefDelta(data, pos, objects)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
objType = actualType
|
||||
case ObjectTypeCommit, ObjectTypeTree, ObjectTypeBlob, ObjectTypeTag:
|
||||
objData, pos, err = zlibDecompress(data, pos)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("unknown object type: %d", objType)
|
||||
}
|
||||
|
||||
hash, err := computeObjectHash(objType, objData)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return &ParsedObject{
|
||||
Type: objType,
|
||||
Size: size,
|
||||
Data: objData,
|
||||
Offset: offset,
|
||||
Hash: hash,
|
||||
}, pos, nil
|
||||
}
|
||||
|
||||
func parseOfsDelta(data []byte, pos int, currentOffset int, objects map[string]*ParsedObject) ([]byte, int, int, error) {
|
||||
b := data[pos]
|
||||
pos++
|
||||
offset := int(b & 0x7f)
|
||||
|
||||
for b&0x80 != 0 {
|
||||
offset++
|
||||
offset <<= 7
|
||||
b = data[pos]
|
||||
pos++
|
||||
offset += int(b & 0x7f)
|
||||
}
|
||||
|
||||
baseOffset := currentOffset - offset
|
||||
baseObject, _, err := parsePackObject(data, baseOffset, objects)
|
||||
if err != nil {
|
||||
return nil, 0, 0, fmt.Errorf("failed to parse base object at offset %d: %w", baseOffset, err)
|
||||
}
|
||||
|
||||
delta, newPos, err := zlibDecompress(data, pos)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
fullObj, err := applyDelta(delta, baseObject.Data)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
return fullObj, newPos, baseObject.Type, nil
|
||||
}
|
||||
|
||||
func parseRefDelta(data []byte, pos int, objects map[string]*ParsedObject) ([]byte, int, int, error) {
|
||||
baseName := hex.EncodeToString(data[pos : pos+20])
|
||||
pos += 20
|
||||
|
||||
delta, newPos, err := zlibDecompress(data, pos)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
baseObject, ok := objects[baseName]
|
||||
if !ok {
|
||||
return nil, 0, 0, fmt.Errorf("base object not found with name %s", baseName)
|
||||
}
|
||||
|
||||
fullObj, err := applyDelta(delta, baseObject.Data)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
return fullObj, newPos, baseObject.Type, nil
|
||||
}
|
||||
|
||||
func computeObjectHash(objType int, data []byte) (string, error) {
|
||||
var typeStr string
|
||||
switch objType {
|
||||
case ObjectTypeCommit:
|
||||
typeStr = "commit"
|
||||
case ObjectTypeTree:
|
||||
typeStr = "tree"
|
||||
case ObjectTypeBlob:
|
||||
typeStr = "blob"
|
||||
case ObjectTypeTag:
|
||||
typeStr = "tag"
|
||||
default:
|
||||
return "", fmt.Errorf("unknown type when computing object hash: %d", objType)
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("%s %d\x00", typeStr, len(data))
|
||||
h := sha1.New()
|
||||
h.Write([]byte(header))
|
||||
h.Write(data)
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func applyDelta(delta []byte, base []byte) ([]byte, error) {
|
||||
pos := 0
|
||||
|
||||
_, bytesRead := readVariableInt(delta, pos)
|
||||
pos += bytesRead
|
||||
|
||||
resultSize, bytesRead := readVariableInt(delta, pos)
|
||||
pos += bytesRead
|
||||
|
||||
result := make([]byte, resultSize)
|
||||
resultOffset := 0
|
||||
|
||||
for pos < len(delta) {
|
||||
cmd := delta[pos]
|
||||
pos++
|
||||
|
||||
if cmd&0x80 != 0 {
|
||||
var copyOffset, copySize int
|
||||
|
||||
if cmd&0x01 != 0 {
|
||||
copyOffset = int(delta[pos])
|
||||
pos++
|
||||
}
|
||||
if cmd&0x02 != 0 {
|
||||
copyOffset |= int(delta[pos]) << 8
|
||||
pos++
|
||||
}
|
||||
if cmd&0x04 != 0 {
|
||||
copyOffset |= int(delta[pos]) << 16
|
||||
pos++
|
||||
}
|
||||
if cmd&0x08 != 0 {
|
||||
copyOffset |= int(delta[pos]) << 24
|
||||
pos++
|
||||
}
|
||||
|
||||
if cmd&0x10 != 0 {
|
||||
copySize = int(delta[pos])
|
||||
pos++
|
||||
}
|
||||
if cmd&0x20 != 0 {
|
||||
copySize |= int(delta[pos]) << 8
|
||||
pos++
|
||||
}
|
||||
if cmd&0x40 != 0 {
|
||||
copySize |= int(delta[pos]) << 16
|
||||
pos++
|
||||
}
|
||||
|
||||
if copySize == 0 {
|
||||
copySize = 0x10000
|
||||
}
|
||||
|
||||
copy(result[resultOffset:], base[copyOffset:copyOffset+copySize])
|
||||
resultOffset += copySize
|
||||
} else if cmd > 0 {
|
||||
copy(result[resultOffset:], delta[pos:pos+int(cmd)])
|
||||
pos += int(cmd)
|
||||
resultOffset += int(cmd)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid delta command")
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func zlibDecompress(data []byte, pos int) ([]byte, int, error) {
|
||||
br := bytes.NewReader(data[pos:])
|
||||
r, err := zlib.NewReader(br)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("zlib init error: %w", err)
|
||||
}
|
||||
|
||||
decompressed, err := io.ReadAll(r)
|
||||
r.Close()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("zlib decompress error: %w", err)
|
||||
}
|
||||
|
||||
newPos := len(data) - br.Len()
|
||||
return decompressed, newPos, nil
|
||||
}
|
||||
|
||||
func readVariableInt(data []byte, pos int) (int, int) {
|
||||
value := 0
|
||||
shift := 0
|
||||
bytesRead := 0
|
||||
|
||||
for {
|
||||
b := data[pos]
|
||||
pos++
|
||||
bytesRead++
|
||||
value |= int(b&0x7f) << shift
|
||||
shift += 7
|
||||
if b&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return value, bytesRead
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type InfoRefsUploadPackResponse struct {
|
||||
Refs map[string]string
|
||||
Capabilities []string
|
||||
Symrefs map[string]string
|
||||
}
|
||||
|
||||
var capabilitiesCache sync.Map
|
||||
|
||||
func GetCapabilities(url string, existingInfo *InfoRefsUploadPackResponse) ([]string, error) {
|
||||
if existingInfo != nil {
|
||||
capabilitiesCache.Store(url, existingInfo.Capabilities)
|
||||
return existingInfo.Capabilities, nil
|
||||
}
|
||||
|
||||
if cached, ok := capabilitiesCache.Load(url); ok {
|
||||
return cached.([]string), nil
|
||||
}
|
||||
|
||||
info, err := GetInfoRefs(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
capabilitiesCache.Store(url, info.Capabilities)
|
||||
return info.Capabilities, nil
|
||||
}
|
||||
|
||||
func GetInfoRefs(url string) (*InfoRefsUploadPackResponse, error) {
|
||||
resp, err := http.Get(url + "/info/refs?service=git-upload-pack")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch info/refs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read info/refs response: %w", err)
|
||||
}
|
||||
response := string(body)
|
||||
|
||||
result := &InfoRefsUploadPackResponse{
|
||||
Refs: make(map[string]string),
|
||||
Symrefs: make(map[string]string),
|
||||
}
|
||||
|
||||
lines := strings.Split(response, "\n")
|
||||
firstRef := true
|
||||
for _, line := range lines {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "0000") {
|
||||
line = line[4:]
|
||||
}
|
||||
if len(line) < 4 {
|
||||
continue
|
||||
}
|
||||
|
||||
length, err := strconv.ParseInt(line[:4], 16, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
endIdx := int(length)
|
||||
if endIdx > len(line) {
|
||||
endIdx = len(line)
|
||||
}
|
||||
if endIdx <= 4 {
|
||||
continue
|
||||
}
|
||||
content := line[4:endIdx]
|
||||
|
||||
if firstRef && strings.HasPrefix(content, "# service=") {
|
||||
firstRef = false
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(content, " ") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(content, " ", 2)
|
||||
hash := parts[0]
|
||||
refAndCaps := parts[1]
|
||||
|
||||
if strings.Contains(refAndCaps, "\x00") {
|
||||
nulParts := strings.SplitN(refAndCaps, "\x00", 2)
|
||||
ref := strings.TrimSpace(nulParts[0])
|
||||
result.Refs[ref] = hash
|
||||
|
||||
caps := strings.Fields(nulParts[1])
|
||||
result.Capabilities = caps
|
||||
|
||||
for _, cap := range caps {
|
||||
if strings.HasPrefix(cap, "symref=") {
|
||||
symrefData := cap[7:]
|
||||
colonIdx := strings.Index(symrefData, ":")
|
||||
if colonIdx != -1 {
|
||||
result.Symrefs[symrefData[:colonIdx]] = symrefData[colonIdx+1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.Refs[strings.TrimSpace(refAndCaps)] = hash
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package gitnaturalapi
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
type TreeEntry struct {
|
||||
Path string
|
||||
Mode string
|
||||
IsDir bool
|
||||
Hash string
|
||||
}
|
||||
|
||||
type TreeFile struct {
|
||||
Name string
|
||||
Hash string
|
||||
Content []byte
|
||||
}
|
||||
|
||||
type TreeDirectory struct {
|
||||
Name string
|
||||
Hash string
|
||||
Content *Tree
|
||||
}
|
||||
|
||||
type Tree struct {
|
||||
Directories []TreeDirectory
|
||||
Files []TreeFile
|
||||
}
|
||||
|
||||
func LoadTree(obj *ParsedObject, objects map[string]*ParsedObject, depth *int) *Tree {
|
||||
directories := make([]TreeDirectory, 0)
|
||||
files := make([]TreeFile, 0)
|
||||
entries := ParseTree(obj.Data)
|
||||
|
||||
for _, entry := range entries {
|
||||
child := objects[entry.Hash]
|
||||
|
||||
if entry.IsDir {
|
||||
var content *Tree
|
||||
if child != nil && (depth == nil || *depth > 0) {
|
||||
var newDepth *int
|
||||
if depth != nil {
|
||||
d := *depth - 1
|
||||
newDepth = &d
|
||||
}
|
||||
content = LoadTree(child, objects, newDepth)
|
||||
}
|
||||
directories = append(directories, TreeDirectory{
|
||||
Name: entry.Path,
|
||||
Hash: entry.Hash,
|
||||
Content: content,
|
||||
})
|
||||
} else {
|
||||
var content []byte
|
||||
if child != nil {
|
||||
content = child.Data
|
||||
}
|
||||
files = append(files, TreeFile{
|
||||
Name: entry.Path,
|
||||
Hash: entry.Hash,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &Tree{Directories: directories, Files: files}
|
||||
}
|
||||
|
||||
func ParseTree(treeData []byte) []TreeEntry {
|
||||
entries := make([]TreeEntry, 0)
|
||||
offset := 0
|
||||
|
||||
for offset < len(treeData) {
|
||||
modeEnd := offset
|
||||
for treeData[modeEnd] != 0x20 {
|
||||
modeEnd++
|
||||
}
|
||||
mode := string(treeData[offset:modeEnd])
|
||||
offset = modeEnd + 1
|
||||
|
||||
filenameEnd := offset
|
||||
for treeData[filenameEnd] != 0x00 {
|
||||
filenameEnd++
|
||||
}
|
||||
path := string(treeData[offset:filenameEnd])
|
||||
offset = filenameEnd + 1
|
||||
|
||||
hash := hex.EncodeToString(treeData[offset : offset+20])
|
||||
offset += 20
|
||||
|
||||
isDir := mode == "40000" || mode == "040000"
|
||||
|
||||
entries = append(entries, TreeEntry{
|
||||
Mode: mode,
|
||||
Path: path,
|
||||
Hash: hash,
|
||||
IsDir: isDir,
|
||||
})
|
||||
}
|
||||
|
||||
return entries
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package grasp
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"fiatjaf.com/nostr/nip19"
|
||||
)
|
||||
|
||||
func IsGraspURL(u string) bool {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.Count(parsed.Path, "/") != 2 || len(parsed.Path) < 65 {
|
||||
return false
|
||||
}
|
||||
|
||||
if prefix, _, err := nip19.Decode(parsed.Path[1:64]); err != nil || prefix != "npub" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
+1
-1
@@ -112,7 +112,7 @@ func NewBunker(
|
||||
onAuth func(string),
|
||||
) *BunkerClient {
|
||||
if pool == nil {
|
||||
pool = nostr.NewPool(nostr.PoolOptions{})
|
||||
pool = nostr.NewPool()
|
||||
}
|
||||
|
||||
clientPublicKey := nostr.GetPublicKey(clientSecretKey)
|
||||
|
||||
@@ -67,7 +67,7 @@ func NewBunkerFromNostrConnect(
|
||||
pool *nostr.Pool,
|
||||
) (*BunkerClient, error) {
|
||||
if pool == nil {
|
||||
pool = nostr.NewPool(nostr.PoolOptions{})
|
||||
pool = nostr.NewPool()
|
||||
}
|
||||
|
||||
if len(relayURLs) == 0 {
|
||||
|
||||
+12
-8
@@ -11,20 +11,24 @@ import (
|
||||
)
|
||||
|
||||
func NormalizeIdentifier(name string) string {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
res, _, _ := transform.Bytes(norm.NFKC, []byte(name))
|
||||
runes := []rune(string(res))
|
||||
runes := []rune(strings.ToLower(string(res)))
|
||||
|
||||
b := make([]rune, len(runes))
|
||||
for i, letter := range runes {
|
||||
words := make([]string, 0, 3)
|
||||
word := make([]rune, 0, 12)
|
||||
for _, letter := range runes {
|
||||
if unicode.IsLetter(letter) || unicode.IsNumber(letter) {
|
||||
b[i] = letter
|
||||
} else {
|
||||
b[i] = '-'
|
||||
word = append(word, letter)
|
||||
} else if len(word) > 0 {
|
||||
words = append(words, string(word))
|
||||
word = make([]rune, 0, 12)
|
||||
}
|
||||
}
|
||||
if len(word) > 0 {
|
||||
words = append(words, string(word))
|
||||
}
|
||||
|
||||
return string(b)
|
||||
return strings.Join(words, "-")
|
||||
}
|
||||
|
||||
func ArticleAsHTML(content string) string {
|
||||
|
||||
+1
-1
@@ -13,7 +13,7 @@ func TestNormalization(t *testing.T) {
|
||||
}{
|
||||
{" hello ", "hello"},
|
||||
{"Goodbye", "goodbye"},
|
||||
{"the long and winding road / that leads to your door", "the-long-and-winding-road---that-leads-to-your-door"},
|
||||
{"the long and winding road / that leads to your door", "the-long-and-winding-road-that-leads-to-your-door"},
|
||||
{"it's 平仮名", "it-s-平仮名"},
|
||||
} {
|
||||
if norm := NormalizeIdentifier(vector.before); norm != vector.after {
|
||||
|
||||
+40
-8
@@ -94,6 +94,11 @@ meltworked:
|
||||
nil,
|
||||
)
|
||||
|
||||
// mark tokens as reserved before attempting melt
|
||||
for _, i := range chosen.tokenIndexes {
|
||||
w.Tokens[i].reserved = true
|
||||
}
|
||||
|
||||
// request from mint to _melt_ into paying the invoice
|
||||
delay := 200 * time.Millisecond
|
||||
// this request will block until the invoice is paid or it fails
|
||||
@@ -103,17 +108,44 @@ meltworked:
|
||||
Inputs: chosen.proofs,
|
||||
Outputs: preChange.bm,
|
||||
})
|
||||
inspectmeltstatusresponse:
|
||||
if err != nil || meltStatus.State == nut05.Unpaid {
|
||||
return "", fmt.Errorf("error melting token: %w", err)
|
||||
} else if meltStatus.State == nut05.Unknown {
|
||||
return "", fmt.Errorf("we don't know what happened with the melt at %s: %v", chosen.mint, meltStatus)
|
||||
} else if meltStatus.State == nut05.Pending {
|
||||
for {
|
||||
for {
|
||||
if err != nil || meltStatus.State == nut05.Unpaid {
|
||||
// unreserve tokens to available state on failure
|
||||
for _, i := range chosen.tokenIndexes {
|
||||
w.Tokens[i].reserved = false
|
||||
}
|
||||
return "", fmt.Errorf("error melting token: %w", err)
|
||||
} else if meltStatus.State == nut05.Unknown {
|
||||
// unreserve tokens to available state on failure
|
||||
for _, i := range chosen.tokenIndexes {
|
||||
w.Tokens[i].reserved = false
|
||||
}
|
||||
return "", fmt.Errorf("we don't know what happened with the melt at %s: %v", chosen.mint, meltStatus)
|
||||
} else if meltStatus.State == nut05.Pending {
|
||||
time.Sleep(delay)
|
||||
delay *= 2
|
||||
meltStatus, err = client.GetMeltQuoteState(ctx, chosen.mint, meltStatus.Quote)
|
||||
goto inspectmeltstatusresponse
|
||||
if err != nil {
|
||||
// unreserve tokens to available state on failure
|
||||
for _, i := range chosen.tokenIndexes {
|
||||
w.Tokens[i].reserved = false
|
||||
}
|
||||
return "", fmt.Errorf("error checking melt status: %w", err)
|
||||
}
|
||||
if meltStatus.State == nut05.Unpaid || meltStatus.State == nut05.Unknown {
|
||||
// unreserve tokens to available state on failure
|
||||
for _, i := range chosen.tokenIndexes {
|
||||
w.Tokens[i].reserved = false
|
||||
}
|
||||
return "", fmt.Errorf("melt failed with state %v", meltStatus.State)
|
||||
} else if meltStatus.State == nut05.Paid {
|
||||
// payment successful
|
||||
break
|
||||
}
|
||||
// continue looping for pending state
|
||||
continue
|
||||
} else if meltStatus.State == nut05.Paid {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -153,6 +153,9 @@ func (w *Wallet) getProofsForSending(
|
||||
) (chosenTokens, uint64, error) {
|
||||
byMint := make(map[string]chosenTokens)
|
||||
for t, token := range w.Tokens {
|
||||
if token.reserved {
|
||||
continue
|
||||
}
|
||||
if fromMint != "" && token.Mint != fromMint {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type Token struct {
|
||||
Proofs cashu.Proofs `json:"proofs"`
|
||||
Deleted []nostr.ID `json:"del,omitempty"`
|
||||
|
||||
reserved bool
|
||||
mintedAt nostr.Timestamp
|
||||
event *nostr.Event
|
||||
}
|
||||
|
||||
@@ -249,6 +249,10 @@ func (w *Wallet) removeDeletedToken(eventId nostr.ID) {
|
||||
func (w *Wallet) Balance() uint64 {
|
||||
var sum uint64
|
||||
for _, token := range w.Tokens {
|
||||
if token.reserved {
|
||||
continue
|
||||
}
|
||||
|
||||
sum += token.Proofs.Amount()
|
||||
}
|
||||
return sum
|
||||
|
||||
@@ -66,13 +66,13 @@ func (bw *BoundWriter) WriteTimestamp(w *bytes.Buffer, timestamp nostr.Timestamp
|
||||
bw.lastTimestampOut = timestamp
|
||||
|
||||
// add 1 to prevent zeroes from being read as infinites
|
||||
WriteVarInt(w, int(delta+1))
|
||||
WriteVarInt(w, uint64(delta)+1)
|
||||
return
|
||||
}
|
||||
|
||||
func (bw *BoundWriter) WriteBound(w *bytes.Buffer, bound Bound) {
|
||||
bw.WriteTimestamp(w, bound.Timestamp)
|
||||
WriteVarInt(w, len(bound.IDPrefix))
|
||||
WriteVarInt(w, uint64(len(bound.IDPrefix)))
|
||||
w.Write(bound.IDPrefix)
|
||||
}
|
||||
|
||||
@@ -111,33 +111,25 @@ func ReadVarInt(reader *bytes.Reader) (int, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func WriteVarInt(w *bytes.Buffer, n int) {
|
||||
func WriteVarInt(w *bytes.Buffer, n uint64) {
|
||||
if n == 0 {
|
||||
w.WriteByte(0)
|
||||
return
|
||||
}
|
||||
|
||||
w.Write(EncodeVarInt(n))
|
||||
}
|
||||
|
||||
func EncodeVarInt(n int) []byte {
|
||||
if n == 0 {
|
||||
return []byte{0}
|
||||
}
|
||||
|
||||
result := make([]byte, 8)
|
||||
idx := 7
|
||||
var buf [10]byte
|
||||
idx := 9
|
||||
|
||||
for n != 0 {
|
||||
result[idx] = byte(n & 0x7F)
|
||||
buf[idx] = byte(n & 0x7F)
|
||||
n >>= 7
|
||||
idx--
|
||||
}
|
||||
|
||||
result = result[idx+1:]
|
||||
result := buf[idx+1:]
|
||||
for i := 0; i < len(result)-1; i++ {
|
||||
result[i] |= 0x80
|
||||
}
|
||||
|
||||
return result
|
||||
w.Write(result)
|
||||
}
|
||||
|
||||
@@ -230,7 +230,7 @@ func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) {
|
||||
finishSkip()
|
||||
|
||||
responseIds := make([]byte, 0, 32*100)
|
||||
responses := 0
|
||||
var responses uint64 = 0
|
||||
|
||||
endBound := currBound
|
||||
|
||||
@@ -284,7 +284,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte
|
||||
// we just send the full ids here
|
||||
n.WriteBound(output, upperBound)
|
||||
output.WriteByte(byte(IdListMode))
|
||||
WriteVarInt(output, numElems)
|
||||
WriteVarInt(output, uint64(numElems))
|
||||
|
||||
for _, item := range n.storage.Range(lower, upper) {
|
||||
output.Write(item.ID[:])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
|
||||
@@ -41,8 +42,8 @@ func (acc *Accumulator) AddBytes(other []byte) {
|
||||
}
|
||||
|
||||
func (acc *Accumulator) GetFingerprint(n int) [negentropy.FingerprintSize]byte {
|
||||
input := acc.Buf[:32]
|
||||
input = append(input, negentropy.EncodeVarInt(n)...)
|
||||
hash := sha256.Sum256(input)
|
||||
input := bytes.NewBuffer(acc.Buf[:32])
|
||||
negentropy.WriteVarInt(input, uint64(n))
|
||||
hash := sha256.Sum256(input.Bytes())
|
||||
return [negentropy.FingerprintSize]byte(hash[:negentropy.FingerprintSize])
|
||||
}
|
||||
|
||||
+16
-6
@@ -128,17 +128,27 @@ func NegentropySync(
|
||||
})
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
wg.Wait()
|
||||
errch <- nil
|
||||
select {
|
||||
case errch <- nil:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
err = <-errch
|
||||
if err != nil {
|
||||
return err
|
||||
select {
|
||||
case err = <-errch:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SyncEventsFromIDs(ctx context.Context, dir Direction) {
|
||||
|
||||
@@ -32,6 +32,24 @@ func DecodeRequest(req Request) (MethodParams, error) {
|
||||
return BanPubKey{pk, reason}, nil
|
||||
case "listbannedpubkeys":
|
||||
return ListBannedPubKeys{}, nil
|
||||
case "unbanpubkey":
|
||||
if len(req.Params) == 0 {
|
||||
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||
}
|
||||
pkh, ok := req.Params[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing pubkey param for '%s'", req.Method)
|
||||
}
|
||||
pk, err := nostr.PubKeyFromHex(pkh)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pubkey param for '%s'", req.Method)
|
||||
}
|
||||
|
||||
var reason string
|
||||
if len(req.Params) >= 2 {
|
||||
reason, _ = req.Params[1].(string)
|
||||
}
|
||||
return UnbanPubKey{pk, reason}, nil
|
||||
case "allowpubkey":
|
||||
if len(req.Params) == 0 {
|
||||
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||
@@ -52,6 +70,24 @@ func DecodeRequest(req Request) (MethodParams, error) {
|
||||
return AllowPubKey{pk, reason}, nil
|
||||
case "listallowedpubkeys":
|
||||
return ListAllowedPubKeys{}, nil
|
||||
case "unallowpubkey":
|
||||
if len(req.Params) == 0 {
|
||||
return nil, fmt.Errorf("invalid number of params for '%s'", req.Method)
|
||||
}
|
||||
pkh, ok := req.Params[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing pubkey param for '%s'", req.Method)
|
||||
}
|
||||
pk, err := nostr.PubKeyFromHex(pkh)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pubkey param for '%s'", req.Method)
|
||||
}
|
||||
|
||||
var reason string
|
||||
if len(req.Params) >= 2 {
|
||||
reason, _ = req.Params[1].(string)
|
||||
}
|
||||
return UnallowPubKey{pk, reason}, nil
|
||||
case "listeventsneedingmoderation":
|
||||
return ListEventsNeedingModeration{}, nil
|
||||
case "allowevent":
|
||||
@@ -219,8 +255,10 @@ var (
|
||||
_ MethodParams = (*SupportedMethods)(nil)
|
||||
_ MethodParams = (*BanPubKey)(nil)
|
||||
_ MethodParams = (*ListBannedPubKeys)(nil)
|
||||
_ MethodParams = (*UnbanPubKey)(nil)
|
||||
_ MethodParams = (*AllowPubKey)(nil)
|
||||
_ MethodParams = (*ListAllowedPubKeys)(nil)
|
||||
_ MethodParams = (*UnallowPubKey)(nil)
|
||||
_ MethodParams = (*ListEventsNeedingModeration)(nil)
|
||||
_ MethodParams = (*AllowEvent)(nil)
|
||||
_ MethodParams = (*BanEvent)(nil)
|
||||
@@ -256,6 +294,13 @@ type ListBannedPubKeys struct{}
|
||||
|
||||
func (ListBannedPubKeys) MethodName() string { return "listbannedpubkeys" }
|
||||
|
||||
type UnbanPubKey struct {
|
||||
PubKey nostr.PubKey
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (UnbanPubKey) MethodName() string { return "unbanpubkey" }
|
||||
|
||||
type AllowPubKey struct {
|
||||
PubKey nostr.PubKey
|
||||
Reason string
|
||||
@@ -267,6 +312,13 @@ type ListAllowedPubKeys struct{}
|
||||
|
||||
func (ListAllowedPubKeys) MethodName() string { return "listallowedpubkeys" }
|
||||
|
||||
type UnallowPubKey struct {
|
||||
PubKey nostr.PubKey
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (UnallowPubKey) MethodName() string { return "unallowpubkey" }
|
||||
|
||||
type ListEventsNeedingModeration struct{}
|
||||
|
||||
func (ListEventsNeedingModeration) MethodName() string { return "listeventsneedingmoderation" }
|
||||
|
||||
@@ -58,6 +58,12 @@ func (c *Client) httpCall(
|
||||
}
|
||||
if resp.Header.StatusCode() >= 300 {
|
||||
reason := resp.Header.Peek("X-Reason")
|
||||
if len(reason) == 0 {
|
||||
reason = resp.Body()
|
||||
if len(reason) > 200 {
|
||||
reason = append(reason[0:199], []byte("…")...)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s returned an error (%d): %s", url, resp.StatusCode(), string(reason))
|
||||
}
|
||||
|
||||
|
||||
+11
-3
@@ -9,6 +9,7 @@ func GetExtension(mimetype string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// hardcode some common cases (abd jbiwb oribkenatuc cases kuje ,ogg/.oga or .mov/.moov)
|
||||
switch mimetype {
|
||||
case "image/jpeg":
|
||||
return ".jpg"
|
||||
@@ -22,13 +23,20 @@ func GetExtension(mimetype string) string {
|
||||
return ".mp4"
|
||||
case "application/vnd.android.package-archive":
|
||||
return ".apk"
|
||||
case "video/quicktime":
|
||||
return ".mov"
|
||||
case "application/vnd.sqlite3":
|
||||
return "sqlite3"
|
||||
case "text/markdown":
|
||||
return "md"
|
||||
case "audio/midi":
|
||||
return "midi"
|
||||
case "audio/x-aiff":
|
||||
return "aiff"
|
||||
}
|
||||
|
||||
exts, _ := mime.ExtensionsByType(mimetype)
|
||||
if len(exts) > 0 {
|
||||
if exts[0] == ".moov" {
|
||||
return ".mov"
|
||||
}
|
||||
return exts[0]
|
||||
}
|
||||
|
||||
|
||||
@@ -27,14 +27,13 @@ type Pool struct {
|
||||
authRequiredHandler func(context.Context, *Event) error
|
||||
cancel context.CancelCauseFunc
|
||||
|
||||
eventMiddleware func(RelayEvent)
|
||||
duplicateMiddleware func(relay string, id ID)
|
||||
queryMiddleware func(relay string, pubkey PubKey, kind Kind)
|
||||
relayOptions RelayOptions
|
||||
EventMiddleware func(RelayEvent)
|
||||
DuplicateMiddleware func(relay string, id ID)
|
||||
QueryMiddleware func(relay string, pubkey PubKey, kind Kind)
|
||||
RelayOptions RelayOptions
|
||||
|
||||
// custom things not often used
|
||||
penaltyBoxMu sync.Mutex
|
||||
penaltyBox map[string][2]float64
|
||||
penaltyBox *xsync.MapOf[string, [2]float64]
|
||||
}
|
||||
|
||||
// DirectedFilter combines a Filter with a specific relay URL.
|
||||
@@ -50,27 +49,15 @@ func (df DirectedFilter) String() string {
|
||||
func (ie RelayEvent) String() string { return fmt.Sprintf("[%s] >> %s", ie.Relay.URL, ie.Event) }
|
||||
|
||||
// NewPool creates a new Pool with the given context and options.
|
||||
func NewPool(opts PoolOptions) *Pool {
|
||||
func NewPool() *Pool {
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
|
||||
pool := &Pool{
|
||||
return &Pool{
|
||||
Relays: xsync.NewMapOf[string, *Relay](),
|
||||
|
||||
Context: ctx,
|
||||
cancel: cancel,
|
||||
|
||||
authRequiredHandler: opts.AuthRequiredHandler,
|
||||
eventMiddleware: opts.EventMiddleware,
|
||||
duplicateMiddleware: opts.DuplicateMiddleware,
|
||||
queryMiddleware: opts.AuthorKindQueryMiddleware,
|
||||
relayOptions: opts.RelayOptions,
|
||||
}
|
||||
|
||||
if opts.PenaltyBox {
|
||||
go pool.startPenaltyBox()
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
type PoolOptions struct {
|
||||
@@ -98,36 +85,50 @@ type PoolOptions struct {
|
||||
RelayOptions RelayOptions
|
||||
}
|
||||
|
||||
func (pool *Pool) startPenaltyBox() {
|
||||
pool.penaltyBox = make(map[string][2]float64)
|
||||
func (pool *Pool) StartPenaltyBox() {
|
||||
pool.penaltyBox = xsync.NewMapOf[string, [2]float64]()
|
||||
|
||||
go func() {
|
||||
sleep := 30.0
|
||||
for {
|
||||
time.Sleep(time.Duration(sleep) * time.Second)
|
||||
select {
|
||||
case <-pool.Context.Done():
|
||||
return
|
||||
case <-time.After(time.Duration(sleep) * time.Second):
|
||||
|
||||
pool.penaltyBoxMu.Lock()
|
||||
nextSleep := 300.0
|
||||
for url, v := range pool.penaltyBox {
|
||||
remainingSeconds := v[1]
|
||||
remainingSeconds -= sleep
|
||||
if remainingSeconds <= 0 {
|
||||
pool.penaltyBox[url] = [2]float64{v[0], 0}
|
||||
continue
|
||||
} else {
|
||||
pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds}
|
||||
nextSleep := 300.0
|
||||
for url, v := range pool.penaltyBox.Range {
|
||||
remainingSeconds := v[1]
|
||||
remainingSeconds -= sleep
|
||||
if remainingSeconds <= 0 {
|
||||
pool.penaltyBox.Store(url, [2]float64{v[0], 0})
|
||||
continue
|
||||
} else {
|
||||
pool.penaltyBox.Store(url, [2]float64{v[0], remainingSeconds})
|
||||
}
|
||||
|
||||
if remainingSeconds < nextSleep {
|
||||
nextSleep = remainingSeconds
|
||||
}
|
||||
}
|
||||
|
||||
if remainingSeconds < nextSleep {
|
||||
nextSleep = remainingSeconds
|
||||
}
|
||||
sleep = nextSleep
|
||||
}
|
||||
|
||||
sleep = nextSleep
|
||||
pool.penaltyBoxMu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// AddToPenaltyBox manually adds a relay to the penalty box for the specified duration.
|
||||
// This prevents EnsureRelay from attempting to connect to the relay until the duration expires.
|
||||
func (pool *Pool) AddToPenaltyBox(url string, duration time.Duration) {
|
||||
if pool.penaltyBox == nil {
|
||||
return
|
||||
}
|
||||
nm := NormalizeURL(url)
|
||||
pool.penaltyBox.Store(nm, [2]float64{0, duration.Seconds()})
|
||||
pool.Relays.Store(nm, nil) // mark as explicitly disconnected for penalty box detection
|
||||
}
|
||||
|
||||
// EnsureRelay ensures that a relay connection exists and is active.
|
||||
// If the relay is not connected, it attempts to connect.
|
||||
func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
||||
@@ -137,9 +138,7 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
||||
relay, ok := pool.Relays.Load(nm)
|
||||
if ok && relay == nil {
|
||||
if pool.penaltyBox != nil {
|
||||
pool.penaltyBoxMu.Lock()
|
||||
defer pool.penaltyBoxMu.Unlock()
|
||||
v, _ := pool.penaltyBox[nm]
|
||||
v, _ := pool.penaltyBox.Load(nm)
|
||||
if v[1] > 0 {
|
||||
return nil, fmt.Errorf("in penalty box, %fs remaining", v[1])
|
||||
}
|
||||
@@ -149,21 +148,27 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) {
|
||||
return relay, nil
|
||||
}
|
||||
|
||||
relay = NewRelay(pool.Context, url, pool.relayOptions)
|
||||
relay = NewRelay(pool.Context, url, pool.RelayOptions)
|
||||
// try to connect
|
||||
// we use this ctx here so when the pool dies everything dies
|
||||
if err := relay.Connect(pool.Context); err != nil {
|
||||
if pool.penaltyBox != nil {
|
||||
// putting relay in penalty box
|
||||
pool.penaltyBoxMu.Lock()
|
||||
defer pool.penaltyBoxMu.Unlock()
|
||||
v, _ := pool.penaltyBox[nm]
|
||||
pool.penaltyBox[nm] = [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)}
|
||||
pool.penaltyBox.Compute(nm, func(v [2]float64, loaded bool) (newV [2]float64, delete bool) {
|
||||
return [2]float64{v[0] + 1, 30.0 + math.Pow(2, v[0]+1)}, false
|
||||
})
|
||||
pool.Relays.Store(nm, nil) // this is important for penalty box detection on EnsureRelay
|
||||
}
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
pool.Relays.Store(nm, relay)
|
||||
go func(r *Relay, relayURL string) {
|
||||
<-r.Context().Done()
|
||||
if current, ok := pool.Relays.Load(relayURL); ok && current == r {
|
||||
pool.Relays.Delete(relayURL)
|
||||
}
|
||||
}(relay, nm)
|
||||
return relay, nil
|
||||
}
|
||||
|
||||
@@ -275,8 +280,8 @@ func (pool *Pool) fetchMany(
|
||||
if opts.CheckDuplicate == nil {
|
||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||
if exists && pool.duplicateMiddleware != nil {
|
||||
pool.duplicateMiddleware(relay, id)
|
||||
if exists && pool.DuplicateMiddleware != nil {
|
||||
pool.DuplicateMiddleware(relay, id)
|
||||
}
|
||||
return exists
|
||||
}
|
||||
@@ -357,7 +362,7 @@ func (pool *Pool) FetchManyReplaceable(
|
||||
go func(nm string) {
|
||||
defer wg.Done()
|
||||
|
||||
if mh := pool.queryMiddleware; mh != nil {
|
||||
if mh := pool.QueryMiddleware; mh != nil {
|
||||
if filter.Kinds != nil && filter.Authors != nil {
|
||||
for _, kind := range filter.Kinds {
|
||||
for _, author := range filter.Authors {
|
||||
@@ -405,7 +410,7 @@ func (pool *Pool) FetchManyReplaceable(
|
||||
}
|
||||
|
||||
ie := RelayEvent{Event: evt, Relay: relay}
|
||||
if mh := pool.eventMiddleware; mh != nil {
|
||||
if mh := pool.EventMiddleware; mh != nil {
|
||||
mh(ie)
|
||||
}
|
||||
|
||||
@@ -448,51 +453,55 @@ func (pool *Pool) subMany(
|
||||
if opts.CheckDuplicate == nil {
|
||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||
_, exists := seenAlready.LoadOrStore(id, Now())
|
||||
if exists && pool.duplicateMiddleware != nil {
|
||||
pool.duplicateMiddleware(relay, id)
|
||||
if exists && pool.DuplicateMiddleware != nil {
|
||||
pool.DuplicateMiddleware(relay, id)
|
||||
}
|
||||
return exists
|
||||
}
|
||||
}
|
||||
|
||||
pending := xsync.NewCounter()
|
||||
pending.Add(int64(len(urls)))
|
||||
pendingWg := sync.WaitGroup{}
|
||||
pendingWg.Add(len(urls))
|
||||
|
||||
go func() {
|
||||
pendingWg.Wait()
|
||||
close(events)
|
||||
cancel(fmt.Errorf("aborted: %w", context.Cause(ctx)))
|
||||
if closedChan != nil {
|
||||
close(closedChan)
|
||||
}
|
||||
}()
|
||||
|
||||
for i, url := range urls {
|
||||
url = NormalizeURL(url)
|
||||
urls[i] = url
|
||||
if idx := slices.Index(urls, url); idx != i {
|
||||
// skip duplicate relays in the list
|
||||
eoseWg.Done()
|
||||
pending.Dec()
|
||||
pendingWg.Done()
|
||||
continue
|
||||
}
|
||||
|
||||
eosed := atomic.Bool{}
|
||||
|
||||
go func(nm string) {
|
||||
go func(nm string, filter Filter) {
|
||||
defer func() {
|
||||
pending.Dec()
|
||||
if pending.Value() == 0 {
|
||||
close(events)
|
||||
cancel(fmt.Errorf("aborted: %w", context.Cause(ctx)))
|
||||
}
|
||||
if eosed.CompareAndSwap(false, true) {
|
||||
eoseWg.Done()
|
||||
}
|
||||
pendingWg.Done()
|
||||
}()
|
||||
|
||||
hasAuthed := false
|
||||
interval := 3 * time.Second
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
var sub *Subscription
|
||||
|
||||
if mh := pool.queryMiddleware; mh != nil {
|
||||
if mh := pool.QueryMiddleware; mh != nil {
|
||||
if filter.Kinds != nil && filter.Authors != nil {
|
||||
for _, kind := range filter.Kinds {
|
||||
for _, author := range filter.Authors {
|
||||
@@ -542,7 +551,7 @@ func (pool *Pool) subMany(
|
||||
}
|
||||
|
||||
ie := RelayEvent{Event: evt, Relay: relay}
|
||||
if mh := pool.eventMiddleware; mh != nil {
|
||||
if mh := pool.EventMiddleware; mh != nil {
|
||||
mh(ie)
|
||||
}
|
||||
|
||||
@@ -567,10 +576,13 @@ func (pool *Pool) subMany(
|
||||
if err == nil {
|
||||
hasAuthed = true // so we don't keep doing AUTH again and again
|
||||
if closedChan != nil {
|
||||
closedChan <- RelayClosed{
|
||||
select {
|
||||
case closedChan <- RelayClosed{
|
||||
Reason: reason,
|
||||
Relay: relay,
|
||||
HandledAuth: true,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
goto subscribe
|
||||
@@ -578,9 +590,12 @@ func (pool *Pool) subMany(
|
||||
}
|
||||
debugLogf("CLOSED from %s: '%s'\n", nm, reason)
|
||||
if closedChan != nil {
|
||||
closedChan <- RelayClosed{
|
||||
select {
|
||||
case closedChan <- RelayClosed{
|
||||
Reason: reason,
|
||||
Relay: relay,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
@@ -597,7 +612,7 @@ func (pool *Pool) subMany(
|
||||
time.Sleep(interval)
|
||||
interval = min(10*time.Minute, interval*17/10) // the next time we try we will wait longer
|
||||
}
|
||||
}(url)
|
||||
}(url, filter)
|
||||
}
|
||||
|
||||
return events
|
||||
@@ -621,13 +636,16 @@ func (pool *Pool) subManyEose(
|
||||
wg.Wait()
|
||||
cancel(errors.New("all subscriptions ended"))
|
||||
close(events)
|
||||
if closedChan != nil {
|
||||
close(closedChan)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, url := range urls {
|
||||
go func(nm string) {
|
||||
defer wg.Done()
|
||||
|
||||
if mh := pool.queryMiddleware; mh != nil {
|
||||
if mh := pool.QueryMiddleware; mh != nil {
|
||||
if filter.Kinds != nil && filter.Authors != nil {
|
||||
for _, kind := range filter.Kinds {
|
||||
for _, author := range filter.Authors {
|
||||
@@ -665,10 +683,13 @@ func (pool *Pool) subManyEose(
|
||||
if err == nil {
|
||||
hasAuthed = true // so we don't keep doing AUTH again and again
|
||||
if closedChan != nil {
|
||||
closedChan <- RelayClosed{
|
||||
select {
|
||||
case closedChan <- RelayClosed{
|
||||
Relay: relay,
|
||||
Reason: reason,
|
||||
HandledAuth: true,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
goto subscribe
|
||||
@@ -676,9 +697,12 @@ func (pool *Pool) subManyEose(
|
||||
}
|
||||
debugLogf("[pool] CLOSED from %s: '%s'\n", nm, reason)
|
||||
if closedChan != nil {
|
||||
closedChan <- RelayClosed{
|
||||
select {
|
||||
case closedChan <- RelayClosed{
|
||||
Relay: relay,
|
||||
Reason: reason,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -688,7 +712,7 @@ func (pool *Pool) subManyEose(
|
||||
}
|
||||
|
||||
ie := RelayEvent{Event: evt, Relay: relay}
|
||||
if mh := pool.eventMiddleware; mh != nil {
|
||||
if mh := pool.EventMiddleware; mh != nil {
|
||||
mh(ie)
|
||||
}
|
||||
|
||||
@@ -783,21 +807,40 @@ func (pool *Pool) batchedQueryMany(
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(dfs))
|
||||
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
||||
forwardWg := sync.WaitGroup{}
|
||||
|
||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||
if exists && pool.duplicateMiddleware != nil {
|
||||
pool.duplicateMiddleware(relay, id)
|
||||
if exists && pool.DuplicateMiddleware != nil {
|
||||
pool.DuplicateMiddleware(relay, id)
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
for _, df := range dfs {
|
||||
go func(df DirectedFilter) {
|
||||
var innerClosed chan RelayClosed
|
||||
if closedChan != nil {
|
||||
innerClosed = make(chan RelayClosed)
|
||||
forwardWg.Add(1)
|
||||
go func() {
|
||||
defer forwardWg.Done()
|
||||
for rc := range innerClosed {
|
||||
select {
|
||||
case closedChan <- rc:
|
||||
case <-ctx.Done():
|
||||
for range innerClosed {
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for ie := range pool.subManyEose(ctx,
|
||||
[]string{df.Relay},
|
||||
df.Filter,
|
||||
closedChan,
|
||||
innerClosed,
|
||||
opts,
|
||||
) {
|
||||
select {
|
||||
@@ -814,6 +857,10 @@ func (pool *Pool) batchedQueryMany(
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(res)
|
||||
if closedChan != nil {
|
||||
forwardWg.Wait()
|
||||
close(closedChan)
|
||||
}
|
||||
}()
|
||||
|
||||
return res
|
||||
@@ -849,22 +896,41 @@ func (pool *Pool) batchedSubscribeMany(
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(dfs))
|
||||
seenAlready := xsync.NewMapOf[ID, struct{}]()
|
||||
forwardWg := sync.WaitGroup{}
|
||||
|
||||
opts.CheckDuplicate = func(id ID, relay string) bool {
|
||||
_, exists := seenAlready.LoadOrStore(id, struct{}{})
|
||||
if exists && pool.duplicateMiddleware != nil {
|
||||
pool.duplicateMiddleware(relay, id)
|
||||
if exists && pool.DuplicateMiddleware != nil {
|
||||
pool.DuplicateMiddleware(relay, id)
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
for _, df := range dfs {
|
||||
go func(df DirectedFilter) {
|
||||
var innerClosed chan RelayClosed
|
||||
if closedChan != nil {
|
||||
innerClosed = make(chan RelayClosed)
|
||||
forwardWg.Add(1)
|
||||
go func() {
|
||||
defer forwardWg.Done()
|
||||
for rc := range innerClosed {
|
||||
select {
|
||||
case closedChan <- rc:
|
||||
case <-ctx.Done():
|
||||
for range innerClosed {
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for ie := range pool.subMany(ctx,
|
||||
[]string{df.Relay},
|
||||
df.Filter,
|
||||
nil,
|
||||
closedChan,
|
||||
innerClosed,
|
||||
opts,
|
||||
) {
|
||||
select {
|
||||
@@ -881,6 +947,10 @@ func (pool *Pool) batchedSubscribeMany(
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(res)
|
||||
if closedChan != nil {
|
||||
forwardWg.Wait()
|
||||
close(closedChan)
|
||||
}
|
||||
}()
|
||||
|
||||
return res
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -20,18 +24,37 @@ import (
|
||||
|
||||
var subscriptionIDCounter atomic.Int64
|
||||
|
||||
var (
|
||||
ErrDisconnected = errors.New("<disconnected>")
|
||||
ErrPingFailed = errors.New("<ping failed>")
|
||||
)
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
type closeCause struct {
|
||||
code ws.StatusCode
|
||||
reason string
|
||||
}
|
||||
|
||||
func (c closeCause) Error() string {
|
||||
if c.reason == "" {
|
||||
return "relay closed"
|
||||
}
|
||||
return c.reason
|
||||
}
|
||||
|
||||
// Relay represents a connection to a Nostr relay.
|
||||
type Relay struct {
|
||||
closeMutex sync.Mutex
|
||||
|
||||
URL string
|
||||
requestHeader http.Header // e.g. for origin header
|
||||
|
||||
// websocket connection
|
||||
conn *ws.Conn
|
||||
writeQueue chan writeRequest
|
||||
closed *atomic.Bool
|
||||
closedNotify chan struct{}
|
||||
conn *ws.Conn
|
||||
writeQueue chan writeRequest
|
||||
closed *atomic.Bool
|
||||
|
||||
Subscriptions *xsync.MapOf[int64, *Subscription]
|
||||
|
||||
@@ -39,7 +62,10 @@ type Relay struct {
|
||||
connectionContext context.Context // will be canceled when the connection closes
|
||||
connectionContextCancel context.CancelCauseFunc
|
||||
|
||||
challenge string // NIP-42 challenge, we only keep the last
|
||||
challenge string // NIP-42 challenge, we only keep the last
|
||||
performAuth sync.Once
|
||||
authed bool
|
||||
|
||||
authHandler func(context.Context, *Relay, *Event) error
|
||||
noticeHandler func(*Relay, string) // NIP-01 NOTICEs
|
||||
customHandler func(string) // nonstandard unparseable messages
|
||||
@@ -66,8 +92,32 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
|
||||
customHandler: opts.CustomHandler,
|
||||
noticeHandler: opts.NoticeHandler,
|
||||
authHandler: opts.AuthHandler,
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
if wasClosed := r.closed.Swap(true); wasClosed {
|
||||
return
|
||||
}
|
||||
|
||||
if r.conn != nil {
|
||||
cause := context.Cause(ctx)
|
||||
code := ws.StatusNormalClosure
|
||||
reason := ""
|
||||
var cc closeCause
|
||||
if errors.As(cause, &cc) {
|
||||
code = cc.code
|
||||
reason = cc.reason
|
||||
} else if cause != nil {
|
||||
reason = cause.Error()
|
||||
}
|
||||
|
||||
_ = r.conn.Close(code, reason)
|
||||
}
|
||||
}()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -109,7 +159,18 @@ func (r *Relay) String() string {
|
||||
func (r *Relay) Context() context.Context { return r.connectionContext }
|
||||
|
||||
// IsConnected returns true if the connection to this relay seems to be active.
|
||||
func (r *Relay) IsConnected() bool { return !r.closed.Load() }
|
||||
func (r *Relay) IsConnected() bool {
|
||||
if r.closed.Load() {
|
||||
return false
|
||||
}
|
||||
if r.conn == nil {
|
||||
return false
|
||||
}
|
||||
if r.connectionContext == nil {
|
||||
return false
|
||||
}
|
||||
return r.connectionContext.Err() == nil
|
||||
}
|
||||
|
||||
// Connect tries to establish a websocket connection to r.URL.
|
||||
// If the context expires before the connection is complete, an error is returned.
|
||||
@@ -136,6 +197,9 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
||||
if r.connectionContext == nil || r.Subscriptions == nil {
|
||||
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
|
||||
}
|
||||
if r.connectionContext.Err() != nil {
|
||||
return fmt.Errorf("relay context canceled")
|
||||
}
|
||||
|
||||
if r.URL == "" {
|
||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||
@@ -148,6 +212,128 @@ func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
|
||||
debugLogf("{%s} connecting!\n", r.URL)
|
||||
|
||||
dialCtx := ctx
|
||||
if _, ok := dialCtx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||
}
|
||||
|
||||
dialOpts := &ws.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
|
||||
},
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
for k, v := range r.requestHeader {
|
||||
dialOpts.HTTPHeader[k] = v
|
||||
}
|
||||
|
||||
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.SetReadLimit(2 << 24) // 33MB
|
||||
|
||||
// ping every 19 seconds
|
||||
ticker := time.NewTicker(19 * time.Second)
|
||||
|
||||
// main websocket loop
|
||||
readQueue := make(chan string, 64 /* add some buffer to account for processing/IO mismatches */)
|
||||
|
||||
r.conn = c
|
||||
r.writeQueue = make(chan writeRequest, 64 /* idem */)
|
||||
r.closed = &atomic.Bool{}
|
||||
|
||||
connCtx := r.connectionContext
|
||||
go func() {
|
||||
pingAttempt := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
debugLogf("{%s} pinging\n", r.URL)
|
||||
pingCtx, cancel := context.WithTimeoutCause(connCtx, time.Millisecond*800, errors.New("ping took too long"))
|
||||
err := c.Ping(pingCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
pingAttempt++
|
||||
debugLogf("{%s} error writing ping (attempt %d): %v", r.URL, pingAttempt, err)
|
||||
|
||||
if pingAttempt >= 3 {
|
||||
debugLogf("{%s} error writing ping after multiple attempts; closing websocket", r.URL)
|
||||
_ = r.close(ErrPingFailed)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// ping was OK
|
||||
debugLogf("{%s} ping OK", r.URL)
|
||||
pingAttempt = 0
|
||||
case wr := <-r.writeQueue:
|
||||
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||
writeCtx, cancel := context.WithTimeoutCause(connCtx, time.Second*10, errors.New("write took too long"))
|
||||
err := c.Write(writeCtx, ws.MessageText, wr.msg)
|
||||
cancel()
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "write failed"})
|
||||
if wr.answer != nil {
|
||||
wr.answer <- err
|
||||
}
|
||||
return
|
||||
}
|
||||
if wr.answer != nil {
|
||||
close(wr.answer)
|
||||
}
|
||||
case msg := <-readQueue:
|
||||
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||
r.handleMessage(msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// read loop -- loops back to the main loop
|
||||
go func() {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
for {
|
||||
buf.Reset()
|
||||
|
||||
_, reader, err := c.Reader(connCtx)
|
||||
if err != nil {
|
||||
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to get reader"})
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||
_ = r.close(closeCause{code: ws.StatusAbnormalClosure, reason: "failed to read"})
|
||||
return
|
||||
}
|
||||
|
||||
msg := string(buf.Bytes())
|
||||
select {
|
||||
case readQueue <- msg:
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||
}
|
||||
|
||||
func (r *Relay) handleMessage(message string) {
|
||||
// if this is an "EVENT" we will have this preparser logic that should speed things up a little
|
||||
// as we skip handling duplicate events
|
||||
@@ -178,7 +364,6 @@ func (r *Relay) handleMessage(message string) {
|
||||
|
||||
switch env := envelope.(type) {
|
||||
case *NoticeEnvelope:
|
||||
// see WithNoticeHandler
|
||||
if r.noticeHandler != nil {
|
||||
r.noticeHandler(r, string(*env))
|
||||
} else {
|
||||
@@ -188,7 +373,10 @@ func (r *Relay) handleMessage(message string) {
|
||||
if env.Challenge == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.performAuth = sync.Once{} // this ensures we can try to auth again
|
||||
r.challenge = *env.Challenge
|
||||
|
||||
if r.authHandler != nil {
|
||||
go func() {
|
||||
r.Auth(r.Context(), func(ctx context.Context, evt *Event) error {
|
||||
@@ -244,14 +432,6 @@ func (r *Relay) handleMessage(message string) {
|
||||
|
||||
// Write queues an arbitrary message to be sent to the relay.
|
||||
func (r *Relay) Write(msg []byte) {
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
select {
|
||||
case <-r.closedNotify:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
||||
@@ -260,20 +440,24 @@ func (r *Relay) Write(msg []byte) {
|
||||
|
||||
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
||||
func (r *Relay) WriteWithError(msg []byte) error {
|
||||
ch := make(chan error)
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
select {
|
||||
case <-r.closedNotify:
|
||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
||||
default:
|
||||
ch := make(chan error, 1)
|
||||
|
||||
if r.writeQueue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||
}
|
||||
return <-ch
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case <-r.connectionContext.Done():
|
||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||
}
|
||||
}
|
||||
|
||||
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
|
||||
@@ -286,20 +470,38 @@ func (r *Relay) Publish(ctx context.Context, event Event) error {
|
||||
// You don't have to build the AUTH event yourself, this function takes a function to which the
|
||||
// event that must be signed will be passed, so it's only necessary to sign that.
|
||||
func (r *Relay) Auth(ctx context.Context, sign func(context.Context, *Event) error) error {
|
||||
authEvent := Event{
|
||||
CreatedAt: Now(),
|
||||
Kind: KindClientAuthentication,
|
||||
Tags: Tags{
|
||||
Tag{"relay", r.URL},
|
||||
Tag{"challenge", r.challenge},
|
||||
},
|
||||
Content: "",
|
||||
}
|
||||
if err := sign(ctx, &authEvent); err != nil {
|
||||
return fmt.Errorf("error signing auth event: %w", err)
|
||||
if r.authed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.publish(ctx, authEvent.ID, &AuthEnvelope{Event: authEvent})
|
||||
if r.challenge == "" {
|
||||
return fmt.Errorf("no challenge, can't AUTH")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
r.performAuth.Do(func() {
|
||||
authEvent := Event{
|
||||
CreatedAt: Now(),
|
||||
Kind: KindClientAuthentication,
|
||||
Tags: Tags{
|
||||
Tag{"relay", r.URL},
|
||||
Tag{"challenge", r.challenge},
|
||||
},
|
||||
Content: "",
|
||||
}
|
||||
if err := sign(ctx, &authEvent); err != nil {
|
||||
err = fmt.Errorf("error signing auth event: %w", err)
|
||||
}
|
||||
|
||||
err = r.publish(ctx, authEvent.ID, &AuthEnvelope{Event: authEvent})
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
r.authed = true
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// publish can be used both for EVENT and for AUTH
|
||||
@@ -381,22 +583,20 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
|
||||
// Failure to do that will result in a huge number of halted goroutines being created.
|
||||
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
|
||||
if r.conn == nil {
|
||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
||||
if !r.IsConnected() {
|
||||
return nil, ErrDisconnected
|
||||
}
|
||||
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
sub.cancel(ErrFireFailed)
|
||||
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filter, r.URL, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-r.closedNotify:
|
||||
sub.unsub(ErrDisconnected)
|
||||
case <-ctx.Done():
|
||||
}
|
||||
<-ctx.Done()
|
||||
sub.cancel(nil)
|
||||
}()
|
||||
|
||||
return sub, nil
|
||||
@@ -420,6 +620,7 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
||||
ClosedReason: make(chan string, 1),
|
||||
Filter: filter,
|
||||
match: filter.Matches,
|
||||
eoseTimedOut: make(chan struct{}),
|
||||
}
|
||||
|
||||
sub.checkDuplicate = opts.CheckDuplicate
|
||||
@@ -444,12 +645,45 @@ func (r *Relay) PrepareSubscription(ctx context.Context, filter Filter, opts Sub
|
||||
|
||||
go func() {
|
||||
time.Sleep(opts.MaxWaitForEOSE)
|
||||
close(sub.eoseTimedOut)
|
||||
sub.dispatchEose()
|
||||
}()
|
||||
}
|
||||
|
||||
// if the relay connection dies, cancel this subscription
|
||||
go func() {
|
||||
select {
|
||||
case <-sub.Context.Done():
|
||||
return
|
||||
case <-r.connectionContext.Done():
|
||||
sub.cancel(context.Cause(r.connectionContext))
|
||||
}
|
||||
}()
|
||||
|
||||
// start handling events, eose, unsub etc:
|
||||
go sub.start()
|
||||
go func() {
|
||||
<-sub.Context.Done()
|
||||
|
||||
// mark subscription as closed and send a CLOSE to the relay (naive sync.Once implementation)
|
||||
if sub.live.CompareAndSwap(true, false) {
|
||||
closeMsg := CloseEnvelope(sub.id)
|
||||
closeb, _ := (&closeMsg).MarshalJSON()
|
||||
if err := sub.Relay.WriteWithError(closeb); err != nil {
|
||||
_ = sub.Relay.close(err)
|
||||
}
|
||||
}
|
||||
|
||||
// remove subscription from our map
|
||||
sub.Relay.Subscriptions.Delete(sub.counter)
|
||||
|
||||
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
|
||||
sub.mu.Lock()
|
||||
close(sub.Events)
|
||||
if sub.countResult != nil {
|
||||
close(sub.countResult)
|
||||
}
|
||||
sub.mu.Unlock()
|
||||
}()
|
||||
|
||||
return sub
|
||||
}
|
||||
@@ -484,29 +718,25 @@ func (r *Relay) QueryEvents(filter Filter) iter.Seq[Event] {
|
||||
}
|
||||
|
||||
// Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
|
||||
func (r *Relay) Count(
|
||||
ctx context.Context,
|
||||
filter Filter,
|
||||
opts SubscriptionOptions,
|
||||
) (uint32, []byte, error) {
|
||||
// If opts.AutoAuth is set, it will handle "auth-required:" CLOSEs using RelayOptions.AuthHandler.
|
||||
func (r *Relay) Count(ctx context.Context, filter Filter, opts SubscriptionOptions) (uint32, []byte, error) {
|
||||
v, err := r.countInternal(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if v.Count == nil {
|
||||
return 0, nil, errors.New("count subscription ended without result")
|
||||
}
|
||||
|
||||
return *v.Count, v.HyperLogLog, nil
|
||||
}
|
||||
|
||||
func (r *Relay) countInternal(ctx context.Context, filter Filter, opts SubscriptionOptions) (CountEnvelope, error) {
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
sub.countResult = make(chan CountEnvelope)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
return CountEnvelope{}, err
|
||||
if !r.IsConnected() {
|
||||
return CountEnvelope{}, ErrDisconnected
|
||||
}
|
||||
|
||||
defer sub.unsub(errors.New("countInternal() ended"))
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.CancelFunc
|
||||
@@ -514,13 +744,54 @@ func (r *Relay) countInternal(ctx context.Context, filter Filter, opts Subscript
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
hasAuthed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case count := <-sub.countResult:
|
||||
return count, nil
|
||||
case <-ctx.Done():
|
||||
return CountEnvelope{}, ctx.Err()
|
||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||
sub.countResult = make(chan CountEnvelope, 1)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
sub.cancel(ErrFireFailed)
|
||||
return CountEnvelope{}, fmt.Errorf("couldn't count %v at %s: %w", filter, r.URL, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
sub.cancel(nil)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case count, ok := <-sub.countResult:
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
if !ok || count.Count == nil {
|
||||
return CountEnvelope{}, errors.New("count subscription ended without result")
|
||||
}
|
||||
return count, nil
|
||||
case reason := <-sub.ClosedReason:
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
if strings.HasPrefix(reason, "auth-required:") && r.authHandler != nil && !hasAuthed {
|
||||
authErr := r.Auth(ctx, func(authCtx context.Context, evt *Event) error {
|
||||
return r.authHandler(authCtx, r, evt)
|
||||
})
|
||||
if authErr == nil {
|
||||
hasAuthed = true
|
||||
goto resubscribe
|
||||
}
|
||||
return CountEnvelope{}, fmt.Errorf("failed to auth: %w", authErr)
|
||||
}
|
||||
return CountEnvelope{}, fmt.Errorf("count: CLOSED received: %s", reason)
|
||||
case <-sub.Context.Done():
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
return CountEnvelope{}, context.Cause(sub.Context)
|
||||
case <-ctx.Done():
|
||||
sub.cancel(errors.New("countInternal() ended"))
|
||||
return CountEnvelope{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
resubscribe:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,19 +801,7 @@ func (r *Relay) Close() error {
|
||||
}
|
||||
|
||||
func (r *Relay) close(reason error) error {
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
if r.connectionContextCancel == nil {
|
||||
return fmt.Errorf("relay already closed")
|
||||
}
|
||||
|
||||
if r.conn == nil {
|
||||
return fmt.Errorf("relay not connected")
|
||||
}
|
||||
|
||||
r.connectionContextCancel(reason)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultSchemaURL = "https://raw.githubusercontent.com/nostr-protocol/registry-of-kinds/952c36fcd7129aa85be22d00c2b381ae47ee9c18/schema.yaml"
|
||||
const DefaultSchemaURL = "https://raw.githubusercontent.com/nostr-protocol/registry-of-kinds/ffa18bf6fb5496d755b465b062e18c676df1a5d4/schema.yaml"
|
||||
|
||||
// this is used by hex.Decode in the "hex" validator -- we don't care about data races
|
||||
var hexdummydecoder = make([]byte, 128)
|
||||
|
||||
+13
-13
@@ -124,7 +124,7 @@ func TestValidateNext_ID(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "id", Required: true}
|
||||
next := &ContentSpec{Type: "id", Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -163,7 +163,7 @@ func TestValidateNext_PubKey(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "pubkey", Required: true}
|
||||
next := &ContentSpec{Type: "pubkey", Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -212,7 +212,7 @@ func TestValidateNext_Relay(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "relay", Required: true}
|
||||
next := &ContentSpec{Type: "relay", Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -261,7 +261,7 @@ func TestValidateNext_Kind(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "kind", Required: true}
|
||||
next := &ContentSpec{Type: "kind", Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -298,7 +298,7 @@ func TestValidateNext_Constrained(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "constrained", Required: true, Either: tt.allowed}
|
||||
next := &ContentSpec{Type: "constrained", Required: true, Either: tt.allowed}
|
||||
_, err := v.validateNext(tt.tag, 3, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -342,7 +342,7 @@ func TestValidateNext_GitCommit(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "hex", Min: 40, Max: 40, Required: true}
|
||||
next := &ContentSpec{Type: "hex", Min: 40, Max: 40, Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -376,7 +376,7 @@ func TestValidateNext_Addr(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := &nextSpec{Type: "addr", Required: true}
|
||||
next := &ContentSpec{Type: "addr", Required: true}
|
||||
_, err := v.validateNext(tt.tag, 1, next)
|
||||
if tt.valid {
|
||||
require.NoError(t, err)
|
||||
@@ -393,7 +393,7 @@ func TestValidateNext_Free(t *testing.T) {
|
||||
|
||||
// free type should accept anything
|
||||
tag := nostr.Tag{"test", "any value here", "even", "multiple", "values"}
|
||||
next := &nextSpec{Type: "free", Required: true}
|
||||
next := &ContentSpec{Type: "free", Required: true}
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -403,7 +403,7 @@ func TestValidateNext_UnknownType(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tag := nostr.Tag{"test", "value"}
|
||||
next := &nextSpec{Type: "unknown-type", Required: true}
|
||||
next := &ContentSpec{Type: "unknown-type", Required: true}
|
||||
|
||||
// should not fail when FailOnUnknownType is false (default)
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
@@ -422,13 +422,13 @@ func TestValidateNext_RequiredField(t *testing.T) {
|
||||
|
||||
// test missing required field
|
||||
tag := nostr.Tag{"test"} // only name, missing required value
|
||||
next := &nextSpec{Type: "free", Required: true}
|
||||
next := &ContentSpec{Type: "free", Required: true}
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing index 1")
|
||||
|
||||
// test optional field
|
||||
next = &nextSpec{Type: "free", Required: false}
|
||||
next = &ContentSpec{Type: "free", Required: false}
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -439,7 +439,7 @@ func TestValidateNext_Variadic(t *testing.T) {
|
||||
|
||||
// test variadic field with multiple values
|
||||
tag := nostr.Tag{"test", "value1", "value2", "value3"}
|
||||
next := &nextSpec{Type: "free", Variadic: true}
|
||||
next := &ContentSpec{Type: "free", Variadic: true}
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -450,7 +450,7 @@ func TestValidateNext_Variadic(t *testing.T) {
|
||||
|
||||
// test variadic field with no values (should fail if required)
|
||||
tag = nostr.Tag{"test"}
|
||||
next = &nextSpec{Type: "free", Variadic: true, Required: true}
|
||||
next = &ContentSpec{Type: "free", Variadic: true, Required: true}
|
||||
_, err = v.validateNext(tag, 1, next)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package sdk
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
@@ -55,14 +56,15 @@ func (sys *System) batchLoadAddressableEvents(
|
||||
cm := sync.Mutex{}
|
||||
|
||||
aggregatedContext, aggregatedCancel := context.WithCancel(context.Background())
|
||||
waiting := len(pubkeys)
|
||||
waiting := atomic.Int32{}
|
||||
waiting.Add(int32(len(pubkeys)))
|
||||
|
||||
for i, pubkey := range pubkeys {
|
||||
ctx, cancel := context.WithCancel(ctxs[i])
|
||||
defer cancel()
|
||||
|
||||
// build batched queries for the external relays
|
||||
go func(i int, pubkey nostr.PubKey) {
|
||||
go func(i int, pubkey nostr.PubKey, ctx context.Context) {
|
||||
// gather relays we'll use for this pubkey
|
||||
relays := sys.determineRelaysToQuery(ctx, pubkey, kind)
|
||||
|
||||
@@ -92,11 +94,10 @@ func (sys *System) batchLoadAddressableEvents(
|
||||
wg.Done()
|
||||
|
||||
<-ctx.Done()
|
||||
waiting--
|
||||
if waiting == 0 {
|
||||
if waiting.Add(-1) == 0 {
|
||||
aggregatedCancel()
|
||||
}
|
||||
}(i, pubkey)
|
||||
}(i, pubkey, ctx)
|
||||
}
|
||||
|
||||
// wait for relay batches to be prepared
|
||||
|
||||
@@ -68,7 +68,7 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options
|
||||
// The first context passed to this function within a given batch window will be provided to
|
||||
// the registered BatchFunc.
|
||||
func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
|
||||
c := make(chan Result[V])
|
||||
c := make(chan Result[V], 1)
|
||||
|
||||
// this is sent to batch fn. It contains the key and the channel to return
|
||||
// the result on
|
||||
@@ -83,7 +83,9 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
|
||||
case <-b.thresholdReached:
|
||||
case <-time.After(l.wait):
|
||||
l.batchLock.Lock()
|
||||
l.curBatcher = l.newBatcher()
|
||||
if l.curBatcher == b {
|
||||
l.curBatcher = l.newBatcher()
|
||||
}
|
||||
l.batchLock.Unlock()
|
||||
}
|
||||
|
||||
@@ -117,11 +119,15 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
|
||||
|
||||
l.batchLock.Unlock()
|
||||
|
||||
if v, ok := <-c; ok {
|
||||
return v.Data, v.Error
|
||||
select {
|
||||
case v, ok := <-c:
|
||||
if ok {
|
||||
return v.Data, v.Error
|
||||
}
|
||||
return value, NoValueError
|
||||
case <-ctx.Done():
|
||||
return value, ctx.Err()
|
||||
}
|
||||
|
||||
return value, NoValueError
|
||||
}
|
||||
|
||||
type batcher[K comparable, V any] struct {
|
||||
|
||||
+30
-41
@@ -2,54 +2,39 @@ package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fiatjaf.com/nostr"
|
||||
"fiatjaf.com/nostr/eventstore/slicestore"
|
||||
"fiatjaf.com/nostr/khatru"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamLiveFeed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// start 3 local relays
|
||||
// start 3 local relays using httptest
|
||||
relay1 := khatru.NewRelay()
|
||||
relay2 := khatru.NewRelay()
|
||||
relay3 := khatru.NewRelay()
|
||||
|
||||
for _, r := range []*khatru.Relay{relay1, relay2, relay3} {
|
||||
db := &slicestore.SliceStore{}
|
||||
db.Init()
|
||||
r.UseEventstore(db, 4000)
|
||||
defer db.Close()
|
||||
dbs := make([]*slicestore.SliceStore, 3)
|
||||
for i, r := range []*khatru.Relay{relay1, relay2, relay3} {
|
||||
dbs[i] = &slicestore.SliceStore{}
|
||||
dbs[i].Init()
|
||||
r.UseEventstore(dbs[i], 4000)
|
||||
}
|
||||
|
||||
s1 := make(chan bool)
|
||||
s2 := make(chan bool)
|
||||
s3 := make(chan bool)
|
||||
|
||||
go func() {
|
||||
err := relay1.Start("127.0.0.1", 48481, s1)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
err := relay2.Start("127.0.0.1", 48482, s2)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
err := relay3.Start("127.0.0.1", 48483, s3)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
defer relay1.Shutdown(ctx)
|
||||
defer relay2.Shutdown(ctx)
|
||||
defer relay3.Shutdown(ctx)
|
||||
|
||||
<-s1
|
||||
<-s2
|
||||
<-s3
|
||||
server1 := httptest.NewServer(relay1)
|
||||
server2 := httptest.NewServer(relay2)
|
||||
server3 := httptest.NewServer(relay3)
|
||||
defer server1.Close()
|
||||
defer server2.Close()
|
||||
defer server3.Close()
|
||||
for _, db := range dbs {
|
||||
defer db.Close()
|
||||
}
|
||||
|
||||
// generate two random keypairs for testing
|
||||
sk1 := nostr.Generate()
|
||||
@@ -60,14 +45,18 @@ func TestStreamLiveFeed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url1 := "ws" + server1.URL[4:]
|
||||
url2 := "ws" + server2.URL[4:]
|
||||
url3 := "ws" + server3.URL[4:]
|
||||
|
||||
// first publish relay lists to relay1 for both users
|
||||
relayListEvt1 := nostr.Event{
|
||||
PubKey: pk1,
|
||||
CreatedAt: nostr.Now(),
|
||||
Kind: 10002,
|
||||
Tags: nostr.Tags{
|
||||
{"r", "ws://localhost:48482", "write"},
|
||||
{"r", "ws://localhost:48483", "write"},
|
||||
{"r", url2, "write"},
|
||||
{"r", url3, "write"},
|
||||
},
|
||||
Content: "",
|
||||
}
|
||||
@@ -78,15 +67,15 @@ func TestStreamLiveFeed(t *testing.T) {
|
||||
CreatedAt: nostr.Now(),
|
||||
Kind: 10002,
|
||||
Tags: nostr.Tags{
|
||||
{"r", "ws://localhost:48482", "write"},
|
||||
{"r", "ws://localhost:48483", "write"},
|
||||
{"r", url2, "write"},
|
||||
{"r", url3, "write"},
|
||||
},
|
||||
Content: "",
|
||||
}
|
||||
relayListEvt2.Sign(sk2)
|
||||
|
||||
// publish relay lists to relay1
|
||||
relay, err := nostr.RelayConnect(ctx, "ws://localhost:48481", nostr.RelayOptions{})
|
||||
relay, err := nostr.RelayConnect(ctx, url1, nostr.RelayOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to relay1: %v", err)
|
||||
}
|
||||
@@ -100,7 +89,7 @@ func TestStreamLiveFeed(t *testing.T) {
|
||||
|
||||
// create a new system instance pointing only to relay1 as the "indexer"
|
||||
sys := NewSystem()
|
||||
sys.RelayListRelays = NewRelayStream("ws://localhost:48481")
|
||||
sys.RelayListRelays = NewRelayStream(url1)
|
||||
defer sys.Close()
|
||||
|
||||
// prepublish some events
|
||||
@@ -123,8 +112,8 @@ func TestStreamLiveFeed(t *testing.T) {
|
||||
evt2.Sign(sk2)
|
||||
|
||||
// publish events concurrently to relays 2 and 3
|
||||
go sys.Pool.PublishMany(ctx, []string{"ws://localhost:48482", "ws://localhost:48483"}, evt1)
|
||||
go sys.Pool.PublishMany(ctx, []string{"ws://localhost:48482", "ws://localhost:48483"}, evt2)
|
||||
go sys.Pool.PublishMany(ctx, []string{url2, url3}, evt1)
|
||||
go sys.Pool.PublishMany(ctx, []string{url2, url3}, evt2)
|
||||
|
||||
// start streaming events for both pubkeys
|
||||
events, err := sys.StreamLiveFeed(ctx, []nostr.PubKey{pk1, pk2}, []nostr.Kind{1})
|
||||
@@ -174,8 +163,8 @@ func TestStreamLiveFeed(t *testing.T) {
|
||||
evt2.Sign(sk2)
|
||||
|
||||
// publish events concurrently to relays 2 and 3
|
||||
go sys.Pool.PublishMany(ctx, []string{"ws://localhost:48482", "ws://localhost:48483"}, evt1)
|
||||
go sys.Pool.PublishMany(ctx, []string{"ws://localhost:48482", "ws://localhost:48483"}, evt2)
|
||||
go sys.Pool.PublishMany(ctx, []string{url2, url3}, evt1)
|
||||
go sys.Pool.PublishMany(ctx, []string{url2, url3}, evt2)
|
||||
|
||||
// wait for events
|
||||
receivedEvt1 := false
|
||||
|
||||
+23
-1
@@ -22,6 +22,17 @@ func (sys *System) FetchBookmarkList(ctx context.Context, pubkey nostr.PubKey) G
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchProfileBadgesList(ctx context.Context, pubkey nostr.PubKey) GenericList[string, EventRef] {
|
||||
sys.profileBadgesListCacheOnce.Do(func() {
|
||||
if sys.ProfileBadgesListCache == nil {
|
||||
sys.ProfileBadgesListCache = cache_memory.New[GenericList[string, EventRef]](1000)
|
||||
}
|
||||
})
|
||||
|
||||
ml, _ := fetchGenericList(sys, ctx, pubkey, 10008, kind_10008, parseEventRef, sys.ProfileBadgesListCache)
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchPinList(ctx context.Context, pubkey nostr.PubKey) GenericList[string, EventRef] {
|
||||
sys.pinListCacheOnce.Do(func() {
|
||||
if sys.PinListCache == nil {
|
||||
@@ -33,6 +44,17 @@ func (sys *System) FetchPinList(ctx context.Context, pubkey nostr.PubKey) Generi
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchGitRepositoryList(ctx context.Context, pubkey nostr.PubKey) GenericList[string, EventRef] {
|
||||
sys.gitRepositoryListCacheOnce.Do(func() {
|
||||
if sys.GitRepositoryListCache == nil {
|
||||
sys.GitRepositoryListCache = cache_memory.New[GenericList[string, EventRef]](1000)
|
||||
}
|
||||
})
|
||||
|
||||
ml, _ := fetchGenericList(sys, ctx, pubkey, 10018, kind_10018, parseEventRef, sys.GitRepositoryListCache)
|
||||
return ml
|
||||
}
|
||||
|
||||
func parseEventRef(tag nostr.Tag) (evr EventRef, ok bool) {
|
||||
if len(tag) < 2 {
|
||||
return evr, false
|
||||
@@ -54,5 +76,5 @@ func parseEventRef(tag nostr.Tag) (evr EventRef, ok bool) {
|
||||
return evr, false
|
||||
}
|
||||
|
||||
return evr, false
|
||||
return evr, true
|
||||
}
|
||||
|
||||
@@ -39,6 +39,39 @@ func (sys *System) FetchMuteList(ctx context.Context, pubkey nostr.PubKey) Gener
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchMediaFollowList(ctx context.Context, pubkey nostr.PubKey) GenericList[nostr.PubKey, ProfileRef] {
|
||||
sys.mediaFollowListCacheOnce.Do(func() {
|
||||
if sys.MediaFollowListCache == nil {
|
||||
sys.MediaFollowListCache = cache_memory.New[GenericList[nostr.PubKey, ProfileRef]](1000)
|
||||
}
|
||||
})
|
||||
|
||||
ml, _ := fetchGenericList(sys, ctx, pubkey, 10020, kind_10020, parseProfileRef, sys.MediaFollowListCache)
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchGoodWikiAuthorList(ctx context.Context, pubkey nostr.PubKey) GenericList[nostr.PubKey, ProfileRef] {
|
||||
sys.goodWikiAuthorListCacheOnce.Do(func() {
|
||||
if sys.GoodWikiAuthorListCache == nil {
|
||||
sys.GoodWikiAuthorListCache = cache_memory.New[GenericList[nostr.PubKey, ProfileRef]](1000)
|
||||
}
|
||||
})
|
||||
|
||||
ml, _ := fetchGenericList(sys, ctx, pubkey, 10101, kind_10101, parseProfileRef, sys.GoodWikiAuthorListCache)
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchGitAuthorList(ctx context.Context, pubkey nostr.PubKey) GenericList[nostr.PubKey, ProfileRef] {
|
||||
sys.gitAuthorListCacheOnce.Do(func() {
|
||||
if sys.GitAuthorListCache == nil {
|
||||
sys.GitAuthorListCache = cache_memory.New[GenericList[nostr.PubKey, ProfileRef]](1000)
|
||||
}
|
||||
})
|
||||
|
||||
ml, _ := fetchGenericList(sys, ctx, pubkey, 10017, kind_10017, parseProfileRef, sys.GitAuthorListCache)
|
||||
return ml
|
||||
}
|
||||
|
||||
func (sys *System) FetchFollowSets(ctx context.Context, pubkey nostr.PubKey) GenericSets[nostr.PubKey, ProfileRef] {
|
||||
sys.followSetsCacheOnce.Do(func() {
|
||||
if sys.FollowSetsCache == nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user