mirror of
https://github.com/sot-tech/mochi.git
synced 2026-06-15 09:03:37 -07:00
simplify udp connection ID generator
This commit is contained in:
+3
-1
@@ -1,6 +1,8 @@
|
||||
---
|
||||
run:
|
||||
go: "1.19"
|
||||
# mochi in not written with generics (a.t.m),
|
||||
# so we can check with 1.17
|
||||
go: "1.17"
|
||||
timeout: "5m"
|
||||
output:
|
||||
sort-results: true
|
||||
|
||||
@@ -33,31 +33,21 @@ type ConnectionIDGenerator struct {
|
||||
// It will be overwritten by subsequent calls to Generate.
|
||||
connID []byte
|
||||
|
||||
// scratch is a 32-byte slice that is used as a scratchpad for the generated
|
||||
// HMACs.
|
||||
scratch []byte
|
||||
// the leeway for a timestamp on a connection ID.
|
||||
maxClockSkew time.Duration
|
||||
}
|
||||
|
||||
// NewConnectionIDGenerator creates a new connection ID generator.
|
||||
func NewConnectionIDGenerator(key string) *ConnectionIDGenerator {
|
||||
func NewConnectionIDGenerator(key string, maxClockSkew time.Duration) *ConnectionIDGenerator {
|
||||
return &ConnectionIDGenerator{
|
||||
mac: hmac.New(func() hash.Hash {
|
||||
return xxh3.New()
|
||||
}, []byte(key)),
|
||||
connID: make([]byte, 8),
|
||||
scratch: make([]byte, 32),
|
||||
connID: make([]byte, 8),
|
||||
maxClockSkew: maxClockSkew,
|
||||
}
|
||||
}
|
||||
|
||||
// reset resets the generator.
|
||||
// This is called by other methods of the generator, it's not necessary to call
|
||||
// it after getting a generator from a pool.
|
||||
func (g *ConnectionIDGenerator) reset() {
|
||||
g.mac.Reset()
|
||||
g.connID = g.connID[:8]
|
||||
g.scratch = g.scratch[:0]
|
||||
}
|
||||
|
||||
// Generate generates an 8-byte connection ID as described in BEP 15 for the
|
||||
// given IP and the current time.
|
||||
//
|
||||
@@ -73,15 +63,12 @@ func (g *ConnectionIDGenerator) reset() {
|
||||
// will be reused, so it must not be referenced after returning the generator
|
||||
// to a pool and will be overwritten be subsequent calls to Generate!
|
||||
func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) []byte {
|
||||
g.reset()
|
||||
|
||||
g.mac.Reset()
|
||||
binary.BigEndian.PutUint32(g.connID, uint32(now.Unix()))
|
||||
|
||||
g.mac.Write(g.connID[:4])
|
||||
ipBytes, _ := ip.MarshalBinary()
|
||||
g.mac.Write(ipBytes)
|
||||
g.scratch = g.mac.Sum(g.scratch)
|
||||
copy(g.connID[4:8], g.scratch[:4])
|
||||
g.mac.Write(ip.AsSlice())
|
||||
copy(g.connID[4:8], g.mac.Sum(nil)[:4])
|
||||
|
||||
log.Debug().
|
||||
Stringer("ip", ip).
|
||||
@@ -92,22 +79,20 @@ func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) []byte {
|
||||
}
|
||||
|
||||
// Validate validates the given connection ID for an IP and the current time.
|
||||
func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration) bool {
|
||||
ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0)
|
||||
func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now time.Time) bool {
|
||||
g.mac.Reset()
|
||||
tsBytes := connectionID[:4]
|
||||
ts := time.Unix(int64(binary.BigEndian.Uint32(tsBytes)), 0)
|
||||
log.Debug().
|
||||
Stringer("ip", ip).
|
||||
Time("ts", ts).Time("now", now).
|
||||
Time("ts", ts).
|
||||
Time("now", now).
|
||||
Hex("connID", g.connID).
|
||||
Msg("validating connection ID")
|
||||
if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) {
|
||||
return false
|
||||
}
|
||||
|
||||
g.reset()
|
||||
|
||||
g.mac.Write(connectionID[:4])
|
||||
ipBytes, _ := ip.MarshalBinary()
|
||||
g.mac.Write(ipBytes)
|
||||
g.scratch = g.mac.Sum(g.scratch)
|
||||
return hmac.Equal(g.scratch[:4], connectionID[4:])
|
||||
g.mac.Write(tsBytes)
|
||||
g.mac.Write(ip.AsSlice())
|
||||
return hmac.Equal(g.mac.Sum(nil)[:4], connectionID[4:8]) &&
|
||||
now.Before(ts.Add(ttl)) &&
|
||||
ts.Before(now.Add(g.maxClockSkew))
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"crypto/hmac"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeebo/xxh3"
|
||||
|
||||
"github.com/sot-tech/mochi/pkg/log"
|
||||
)
|
||||
@@ -32,14 +33,14 @@ var golden = []struct {
|
||||
// This is a wrapper around creating a new ConnectionIDGenerator and generating
|
||||
// an ID. It is recommended to use the generator for performance.
|
||||
func NewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
|
||||
return NewConnectionIDGenerator(key).Generate(ip, now)
|
||||
return NewConnectionIDGenerator(key, 0).Generate(ip, now)
|
||||
}
|
||||
|
||||
// ValidConnectionID determines whether a connection identifier is legitimate.
|
||||
// This is a wrapper around creating a new ConnectionIDGenerator and validating
|
||||
// the ID. It is recommended to use the generator for performance.
|
||||
func ValidConnectionID(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration, key string) bool {
|
||||
return NewConnectionIDGenerator(key).Validate(connectionID, ip, now, maxClockSkew)
|
||||
return NewConnectionIDGenerator(key, maxClockSkew).Validate(connectionID, ip, now)
|
||||
}
|
||||
|
||||
// simpleNewConnectionID generates a new connection ID the explicit way.
|
||||
@@ -48,10 +49,11 @@ func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint32(buf, uint32(now.Unix()))
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac := hmac.New(func() hash.Hash {
|
||||
return xxh3.New()
|
||||
}, []byte(key))
|
||||
mac.Write(buf[:4])
|
||||
ipBytes, _ := ip.MarshalBinary()
|
||||
mac.Write(ipBytes)
|
||||
mac.Write(ip.AsSlice())
|
||||
macBytes := mac.Sum(nil)[:4]
|
||||
copy(buf[4:], macBytes)
|
||||
|
||||
@@ -93,7 +95,7 @@ func TestReuseGeneratorGenerate(t *testing.T) {
|
||||
cid := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||
require.Len(t, cid, 8)
|
||||
|
||||
gen := NewConnectionIDGenerator(tt.key)
|
||||
gen := NewConnectionIDGenerator(tt.key, 0)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
|
||||
@@ -106,10 +108,10 @@ func TestReuseGeneratorGenerate(t *testing.T) {
|
||||
func TestReuseGeneratorValidate(t *testing.T) {
|
||||
for _, tt := range golden {
|
||||
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
|
||||
gen := NewConnectionIDGenerator(tt.key)
|
||||
gen := NewConnectionIDGenerator(tt.key, time.Minute)
|
||||
cid := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
|
||||
for i := 0; i < 3; i++ {
|
||||
got := gen.Validate(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0), time.Minute)
|
||||
got := gen.Validate(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0))
|
||||
if got != tt.valid {
|
||||
t.Errorf("expected validity: %t got validity: %t", tt.valid, got)
|
||||
}
|
||||
@@ -159,7 +161,7 @@ func BenchmarkConnectionIDGenerator_Generate(b *testing.B) {
|
||||
|
||||
pool := &sync.Pool{
|
||||
New: func() any {
|
||||
return NewConnectionIDGenerator(key)
|
||||
return NewConnectionIDGenerator(key, 0)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -197,14 +199,14 @@ func BenchmarkConnectionIDGenerator_Validate(b *testing.B) {
|
||||
|
||||
pool := &sync.Pool{
|
||||
New: func() any {
|
||||
return NewConnectionIDGenerator(key)
|
||||
return NewConnectionIDGenerator(key, 10*time.Second)
|
||||
},
|
||||
}
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
gen := pool.Get().(*ConnectionIDGenerator)
|
||||
if !gen.Validate(cid, ip, createdAt, 10*time.Second) {
|
||||
if !gen.Validate(cid, ip, createdAt) {
|
||||
b.FailNow()
|
||||
}
|
||||
pool.Put(gen)
|
||||
|
||||
@@ -78,7 +78,6 @@ type udpFE struct {
|
||||
wg sync.WaitGroup
|
||||
genPool *sync.Pool
|
||||
logic *middleware.Logic
|
||||
maxClockSkew time.Duration
|
||||
collectTimings bool
|
||||
ctxCancel context.CancelFunc
|
||||
onceCloser sync.Once
|
||||
@@ -98,12 +97,11 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend,
|
||||
sockets: make([]*net.UDPConn, cfg.Workers),
|
||||
closing: make(chan any),
|
||||
logic: logic,
|
||||
maxClockSkew: cfg.MaxClockSkew,
|
||||
collectTimings: cfg.EnableRequestTiming,
|
||||
ParseOptions: cfg.ParseOptions,
|
||||
genPool: &sync.Pool{
|
||||
New: func() any {
|
||||
return NewConnectionIDGenerator(cfg.PrivateKey)
|
||||
return NewConnectionIDGenerator(cfg.PrivateKey, cfg.MaxClockSkew)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -241,7 +239,7 @@ func (f *udpFE) handleRequest(ctx context.Context, r Request, w ResponseWriter)
|
||||
|
||||
// If this isn't requesting a new connection ID and the connection ID is
|
||||
// invalid, then fail.
|
||||
if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now(), f.maxClockSkew) {
|
||||
if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now()) {
|
||||
err = errBadConnectionID
|
||||
WriteError(w, txID, err)
|
||||
return
|
||||
|
||||
@@ -140,16 +140,6 @@ func (p *ihSwarm) keys(fn func(k bittorrent.InfoHash) bool) {
|
||||
p.RUnlock()
|
||||
}
|
||||
|
||||
func (p *ihSwarm) forEach(fn func(k bittorrent.InfoHash, v swarm) bool) {
|
||||
p.RLock()
|
||||
for k, v := range p.m {
|
||||
if !fn(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
p.RUnlock()
|
||||
}
|
||||
|
||||
type swarm struct {
|
||||
// map serialized peer to mtime
|
||||
seeders *peers
|
||||
@@ -172,7 +162,6 @@ func (p *peers) set(k bittorrent.Peer, v int64) {
|
||||
p.Lock()
|
||||
p.m[k] = v
|
||||
p.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (p *peers) del(k bittorrent.Peer) (ok bool) {
|
||||
|
||||
Reference in New Issue
Block a user