From b899ef88654fd084f3fcb4882e686606977d2469 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Thu, 23 Apr 2026 20:31:57 -0300 Subject: [PATCH] faster signature verification by serializing directly into the sha with less allocations. --- event.go | 226 +++++++++++++++++++++++++++++++++++--- go.mod | 2 +- go.sum | 4 +- helpers.go | 40 ------- signature.go | 11 +- signature_libsecp256k1.go | 11 +- 6 files changed, 224 insertions(+), 70 deletions(-) diff --git a/event.go b/event.go index 51ad0b1..2435657 100644 --- a/event.go +++ b/event.go @@ -2,7 +2,9 @@ package nostr import ( "crypto/sha256" + "hash" "strconv" + "unsafe" "github.com/mailru/easyjson" "github.com/templexxx/xhex" @@ -26,10 +28,17 @@ func (evt Event) String() string { // GetID serializes and returns the event ID as a string. func (evt Event) GetID() ID { - return sha256.Sum256(evt.Serialize()) + var id ID + evt.serializedHash(&id) + return id } -// CheckID checks if the implied ID matches the given ID more efficiently. +// SetID calculates and sets the id to the event in a single operation. +func (evt *Event) SetID() { + evt.serializedHash(&evt.ID) +} + +// CheckID checks if the implied ID matches the currently assigned ID. func (evt Event) CheckID() bool { return evt.GetID() == evt.ID } @@ -38,17 +47,56 @@ func (evt Event) CheckID() bool { func (evt Event) Serialize() []byte { // the serialization process is just putting everything into a JSON array // so the order is kept. See NIP-01 - dst := make([]byte, 4+64, 100+len(evt.Content)+len(evt.Tags)*80) + dst := make([]byte, 0, 100+len(evt.Content)+len(evt.Tags)*80) + return evt.appendSerialized(dst) +} - // the header portion is easy to serialize - // [0,"pubkey",created_at,kind,[ - copy(dst, `[0,"`) - xhex.Encode(dst[4:4+64], evt.PubKey[:]) // there will always be such capacity +var escTable [256]bool + +// pre-built escape sequences; index by the offending byte. +var escSeq [256][2]byte + +// pre-built []byte slices for hash.Write calls (no per-call allocation). +var escSlice [256][]byte + +var ( + jsonQuote = []byte{'"'} + serializedStart = []byte(`[0,"`) + serializedPubkeyEnd = []byte(`",`) + serializedTagsEnd = []byte("],") + serializedTagStart = []byte{'['} + serializedTagEnd = []byte{']'} + serializedComma = []byte{','} + serializedEnd = []byte{']'} +) + +func init() { + for _, b := range []byte{'"', '\\', '\n', '\r', '\t'} { + escTable[b] = true + } + + escSeq['"'] = [2]byte{'\\', '"'} + escSeq['\\'] = [2]byte{'\\', '\\'} + escSeq['\n'] = [2]byte{'\\', 'n'} + escSeq['\r'] = [2]byte{'\\', 'r'} + escSeq['\t'] = [2]byte{'\\', 't'} + for b, seq := range escSeq { + if escTable[b] { + escSlice[b] = seq[:] + } + } +} + +func (evt Event) appendSerialized(dst []byte) []byte { + start := len(dst) + dst = append(dst, `[0,"`...) + dst = append(dst, make([]byte, 64)...) + xhex.Encode(dst[start+4:start+4+64], evt.PubKey[:]) dst = append(dst, `",`...) - dst = append(dst, strconv.FormatInt(int64(evt.CreatedAt), 10)...) - dst = append(dst, `,`...) - dst = append(dst, strconv.FormatUint(uint64(evt.Kind), 10)...) - dst = append(dst, `,`...) + dst = strconv.AppendInt(dst, int64(evt.CreatedAt), 10) + dst = append(dst, ',') + dst = strconv.AppendUint(dst, uint64(evt.Kind), 10) + dst = append(dst, ',') // tags dst = append(dst, '[') @@ -62,15 +110,167 @@ func (evt Event) Serialize() []byte { if i > 0 { dst = append(dst, ',') } - dst = escapeString(dst, s) + dst = appendJSONString(dst, s) } dst = append(dst, ']') } dst = append(dst, "],"...) // content needs to be escaped in general as it is user generated. - dst = escapeString(dst, evt.Content) + dst = appendJSONString(dst, evt.Content) dst = append(dst, ']') return dst } + +func (evt Event) serializedHash(dst *ID) { + h := sha256.New() + h.Write(serializedStart) + + var pubkeyHex [64]byte + xhex.Encode(pubkeyHex[:], evt.PubKey[:]) + h.Write(pubkeyHex[:]) + h.Write(serializedPubkeyEnd) + + var numBuf [20]byte + b := strconv.AppendInt(numBuf[:0], int64(evt.CreatedAt), 10) + h.Write(b) + h.Write(serializedComma) + b = strconv.AppendUint(numBuf[:0], uint64(evt.Kind), 10) + h.Write(b) + h.Write(serializedComma) + + h.Write(serializedTagStart) + for i, tag := range evt.Tags { + if i > 0 { + h.Write(serializedComma) + } + h.Write(serializedTagStart) + for j, s := range tag { + if j > 0 { + h.Write(serializedComma) + } + writeJSONString(h, s) + } + h.Write(serializedTagEnd) + } + h.Write(serializedTagsEnd) + + writeJSONString(h, evt.Content) + h.Write(serializedEnd) + + h.Sum((*dst)[:0]) +} + +// ── SWAR helper ────────────────────────────────────────────────────────────── + +// hasSpecial returns non-zero if any byte in w is one of: \t 0x09, \n 0x0A, +// " 0x22, \ 0x5C. Uses the classic "hasvalue" bit-trick — no branches, no +// memory, pure ALU. Works regardless of endianness because we only care +// whether a match exists, not where. +// +//go:nosplit +func hasSpecial(w uint64) bool { + match := func(w, v uint64) uint64 { + x := w ^ (0x0101010101010101 * v) + return (x - 0x0101010101010101) & ^x & 0x8080808080808080 + } + return match(w, 0x09)|match(w, 0x0A)|match(w, 0x0D)|match(w, 0x22)|match(w, 0x5C) != 0 +} + +func appendJSONString(dst []byte, s string) []byte { + dst = append(dst, '"') + + n := len(s) + if n == 0 { + return append(dst, '"') + } + + base := uintptr(unsafe.Pointer(unsafe.StringData(s))) + start, i := 0, 0 + + // consume 8 bytes at a time; + // if the whole word is clean, advance without touching dst at all; + // but when a word is dirty, fall back to the byte loop only for that 8-byte window + for i+8 <= n { + w := *(*uint64)(unsafe.Pointer(base + uintptr(i))) + if hasSpecial(w) { + for end := i + 8; i < end; i++ { + if escTable[s[i]] { + // append everything since the start or the last time we did this up to here + dst = append(dst, s[start:i]...) + + // append this special sequence + seq := escSeq[s[i]] + dst = append(dst, seq[0], seq[1]) + + // set this as a checkpoint + start = i + 1 + } + } + } else { + i += 8 + } + } + + // scalar tail for the remaining <8 bytes (same logic used for the hasSpecial branch above) + for ; i < n; i++ { + if escTable[s[i]] { + dst = append(dst, s[start:i]...) + seq := escSeq[s[i]] + dst = append(dst, seq[0], seq[1]) + start = i + 1 + } + } + + // add the remaining chunk (in a string without any specials this will add everything at once) + dst = append(dst, s[start:]...) + + return append(dst, '"') +} + +func writeJSONString(h hash.Hash, s string) { + h.Write(jsonQuote) + + n := len(s) + if n == 0 { + h.Write(jsonQuote) + return + } + + base := uintptr(unsafe.Pointer(unsafe.StringData(s))) + start, i := 0, 0 + + for i+8 <= n { + w := *(*uint64)(unsafe.Pointer(base + uintptr(i))) + // apply same logic as of appendJSONString() + if hasSpecial(w) { + for end := i + 8; i < end; i++ { + if escTable[s[i]] { + if i > start { + h.Write(unsafe.Slice(unsafe.StringData(s[start:i]), i-start)) + } + h.Write(escSlice[s[i]]) + start = i + 1 + } + } + } else { + i += 8 + } + } + + for ; i < n; i++ { + if escTable[s[i]] { + if i > start { + h.Write(unsafe.Slice(unsafe.StringData(s[start:i]), i-start)) + } + h.Write(escSlice[s[i]]) + start = i + 1 + } + } + + if start < n { + h.Write(unsafe.Slice(unsafe.StringData(s[start:]), len(s)-start)) + } + h.Write(jsonQuote) +} diff --git a/go.mod b/go.mod index 740aa19..66f73aa 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( ) require ( - fiatjaf.com/lib v0.3.6 + fiatjaf.com/lib v0.3.7 github.com/dgraph-io/ristretto/v2 v2.3.0 github.com/go-git/go-git/v5 v5.16.3 github.com/pemistahl/lingua-go v1.4.0 diff --git a/go.sum b/go.sum index 55f8e78..2e9ddcc 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -fiatjaf.com/lib v0.3.6 h1:GRZNSxHI2EWdjSKVuzaT+c0aifLDtS16SzkeJaHyJfY= -fiatjaf.com/lib v0.3.6/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8= +fiatjaf.com/lib v0.3.7 h1:mXZOn7NrUcjSdy4oNvwQyAmes7Ueb+Zr5hjqMIe2dxI= +fiatjaf.com/lib v0.3.7/go.mod h1:UlHaZvPHj25PtKLh9GjZkUHRmQ2xZ8Jkoa4VRaLeeQ8= github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc= github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw= github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc= diff --git a/helpers.go b/helpers.go index ba5154a..2d86bc0 100644 --- a/helpers.go +++ b/helpers.go @@ -92,46 +92,6 @@ func similarPublicKey(as, bs []PubKey) bool { return true } -// Escaping strings for JSON encoding according to RFC8259. -// Also encloses result in quotation marks "". -func escapeString(dst []byte, s string) []byte { - dst = append(dst, '"') - for i := 0; i < len(s); i++ { - c := s[i] - switch { - case c == '"': - // quotation mark - dst = append(dst, []byte{'\\', '"'}...) - case c == '\\': - // reverse solidus - dst = append(dst, []byte{'\\', '\\'}...) - case c >= 0x20: - // default, rest below are control chars - dst = append(dst, c) - case c == 0x08: - dst = append(dst, []byte{'\\', 'b'}...) - case c < 0x09: - dst = append(dst, []byte{'\\', 'u', '0', '0', '0', '0' + c}...) - case c == 0x09: - dst = append(dst, []byte{'\\', 't'}...) - case c == 0x0a: - dst = append(dst, []byte{'\\', 'n'}...) - case c == 0x0c: - dst = append(dst, []byte{'\\', 'f'}...) - case c == 0x0d: - dst = append(dst, []byte{'\\', 'r'}...) - case c < 0x10: - dst = append(dst, []byte{'\\', 'u', '0', '0', '0', 0x57 + c}...) - case c < 0x1a: - dst = append(dst, []byte{'\\', 'u', '0', '0', '1', 0x20 + c}...) - case c < 0x20: - dst = append(dst, []byte{'\\', 'u', '0', '0', '1', 0x47 + c}...) - } - } - dst = append(dst, '"') - return dst -} - func subIdToSerial(subId string) int64 { n := strings.Index(subId, ":") if n < 0 || n > len(subId) { diff --git a/signature.go b/signature.go index 88168f5..5c295b1 100644 --- a/signature.go +++ b/signature.go @@ -3,8 +3,6 @@ package nostr import ( - "crypto/sha256" - "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -34,8 +32,8 @@ func (evt Event) VerifySignature() bool { sig := schnorr.NewSignature(&r, &s) // check signature - hash := sha256.Sum256(evt.Serialize()) - return sig.Verify(hash[:], pubkey) + evt.SetID() + return sig.Verify(evt.ID[:], pubkey) } // Sign signs an event with a given privateKey. @@ -52,13 +50,12 @@ func (evt *Event) Sign(secretKey [32]byte) error { pkBytes := pk.SerializeCompressed()[1:] evt.PubKey = PubKey(pkBytes) - h := sha256.Sum256(evt.Serialize()) - sig, err := schnorr.Sign(sk, h[:], schnorr.FastSign()) + evt.SetID() + sig, err := schnorr.Sign(sk, evt.ID[:], schnorr.FastSign()) if err != nil { return err } - evt.ID = h sigb := sig.Serialize() evt.Sig = [64]byte(sigb) diff --git a/signature_libsecp256k1.go b/signature_libsecp256k1.go index 5617ecb..c106410 100644 --- a/signature_libsecp256k1.go +++ b/signature_libsecp256k1.go @@ -25,7 +25,6 @@ import "C" import ( "crypto/rand" - "crypto/sha256" "errors" "unsafe" @@ -33,14 +32,14 @@ import ( ) func (evt Event) VerifySignature() bool { - msg := sha256.Sum256(evt.Serialize()) + evt.SetID() var xonly C.secp256k1_xonly_pubkey if C.secp256k1_xonly_pubkey_parse(globalSecp256k1Context, &xonly, (*C.uchar)(unsafe.Pointer(&evt.PubKey[0]))) != 1 { return false } - res := C.secp256k1_schnorrsig_verify(globalSecp256k1Context, (*C.uchar)(unsafe.Pointer(&evt.Sig[0])), (*C.uchar)(unsafe.Pointer(&msg[0])), 32, &xonly) + res := C.secp256k1_schnorrsig_verify(globalSecp256k1Context, (*C.uchar)(unsafe.Pointer(&evt.Sig[0])), (*C.uchar)(unsafe.Pointer(&evt.ID[0])), 32, &xonly) return res == 1 } @@ -59,16 +58,14 @@ func (evt *Event) Sign(secretKey [32]byte, signOpts ...schnorr.SignOption) error C.secp256k1_keypair_xonly_pub(globalSecp256k1Context, &xonly, nil, &keypair) C.secp256k1_xonly_pubkey_serialize(globalSecp256k1Context, (*C.uchar)(unsafe.Pointer(&evt.PubKey[0])), &xonly) - h := sha256.Sum256(evt.Serialize()) + evt.SetID() var random [32]byte rand.Read(random[:]) - if C.secp256k1_schnorrsig_sign32(globalSecp256k1Context, (*C.uchar)(unsafe.Pointer(&evt.Sig[0])), (*C.uchar)(unsafe.Pointer(&h[0])), &keypair, (*C.uchar)(unsafe.Pointer(&random[0]))) != 1 { + if C.secp256k1_schnorrsig_sign32(globalSecp256k1Context, (*C.uchar)(unsafe.Pointer(&evt.Sig[0])), (*C.uchar)(unsafe.Pointer(&evt.ID[0])), &keypair, (*C.uchar)(unsafe.Pointer(&random[0]))) != 1 { return errors.New("failed to sign message") } - evt.ID = h - return nil }