diff --git a/.golangci.yaml b/.golangci.yaml index 94a7e35..7eb8775 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,6 +1,8 @@ --- run: - go: "1.19" + # mochi in not written with generics (a.t.m), + # so we can check with 1.17 + go: "1.17" timeout: "5m" output: sort-results: true diff --git a/frontend/udp/connection_id.go b/frontend/udp/connection_id.go index 7272e71..8a4546d 100644 --- a/frontend/udp/connection_id.go +++ b/frontend/udp/connection_id.go @@ -33,31 +33,21 @@ type ConnectionIDGenerator struct { // It will be overwritten by subsequent calls to Generate. connID []byte - // scratch is a 32-byte slice that is used as a scratchpad for the generated - // HMACs. - scratch []byte + // the leeway for a timestamp on a connection ID. + maxClockSkew time.Duration } // NewConnectionIDGenerator creates a new connection ID generator. -func NewConnectionIDGenerator(key string) *ConnectionIDGenerator { +func NewConnectionIDGenerator(key string, maxClockSkew time.Duration) *ConnectionIDGenerator { return &ConnectionIDGenerator{ mac: hmac.New(func() hash.Hash { return xxh3.New() }, []byte(key)), - connID: make([]byte, 8), - scratch: make([]byte, 32), + connID: make([]byte, 8), + maxClockSkew: maxClockSkew, } } -// reset resets the generator. -// This is called by other methods of the generator, it's not necessary to call -// it after getting a generator from a pool. -func (g *ConnectionIDGenerator) reset() { - g.mac.Reset() - g.connID = g.connID[:8] - g.scratch = g.scratch[:0] -} - // Generate generates an 8-byte connection ID as described in BEP 15 for the // given IP and the current time. // @@ -73,15 +63,12 @@ func (g *ConnectionIDGenerator) reset() { // will be reused, so it must not be referenced after returning the generator // to a pool and will be overwritten be subsequent calls to Generate! func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) []byte { - g.reset() - + g.mac.Reset() binary.BigEndian.PutUint32(g.connID, uint32(now.Unix())) g.mac.Write(g.connID[:4]) - ipBytes, _ := ip.MarshalBinary() - g.mac.Write(ipBytes) - g.scratch = g.mac.Sum(g.scratch) - copy(g.connID[4:8], g.scratch[:4]) + g.mac.Write(ip.AsSlice()) + copy(g.connID[4:8], g.mac.Sum(nil)[:4]) log.Debug(). Stringer("ip", ip). @@ -92,22 +79,20 @@ func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) []byte { } // Validate validates the given connection ID for an IP and the current time. -func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration) bool { - ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0) +func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now time.Time) bool { + g.mac.Reset() + tsBytes := connectionID[:4] + ts := time.Unix(int64(binary.BigEndian.Uint32(tsBytes)), 0) log.Debug(). Stringer("ip", ip). - Time("ts", ts).Time("now", now). + Time("ts", ts). + Time("now", now). Hex("connID", g.connID). Msg("validating connection ID") - if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) { - return false - } - g.reset() - - g.mac.Write(connectionID[:4]) - ipBytes, _ := ip.MarshalBinary() - g.mac.Write(ipBytes) - g.scratch = g.mac.Sum(g.scratch) - return hmac.Equal(g.scratch[:4], connectionID[4:]) + g.mac.Write(tsBytes) + g.mac.Write(ip.AsSlice()) + return hmac.Equal(g.mac.Sum(nil)[:4], connectionID[4:8]) && + now.Before(ts.Add(ttl)) && + ts.Before(now.Add(g.maxClockSkew)) } diff --git a/frontend/udp/connection_id_test.go b/frontend/udp/connection_id_test.go index 5dcd5a8..aa23157 100644 --- a/frontend/udp/connection_id_test.go +++ b/frontend/udp/connection_id_test.go @@ -4,13 +4,14 @@ import ( "crypto/hmac" "encoding/binary" "fmt" + "hash" "net/netip" "sync" "testing" "time" - "github.com/minio/sha256-simd" "github.com/stretchr/testify/require" + "github.com/zeebo/xxh3" "github.com/sot-tech/mochi/pkg/log" ) @@ -32,14 +33,14 @@ var golden = []struct { // This is a wrapper around creating a new ConnectionIDGenerator and generating // an ID. It is recommended to use the generator for performance. func NewConnectionID(ip netip.Addr, now time.Time, key string) []byte { - return NewConnectionIDGenerator(key).Generate(ip, now) + return NewConnectionIDGenerator(key, 0).Generate(ip, now) } // ValidConnectionID determines whether a connection identifier is legitimate. // This is a wrapper around creating a new ConnectionIDGenerator and validating // the ID. It is recommended to use the generator for performance. func ValidConnectionID(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration, key string) bool { - return NewConnectionIDGenerator(key).Validate(connectionID, ip, now, maxClockSkew) + return NewConnectionIDGenerator(key, maxClockSkew).Validate(connectionID, ip, now) } // simpleNewConnectionID generates a new connection ID the explicit way. @@ -48,10 +49,11 @@ func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte { buf := make([]byte, 8) binary.BigEndian.PutUint32(buf, uint32(now.Unix())) - mac := hmac.New(sha256.New, []byte(key)) + mac := hmac.New(func() hash.Hash { + return xxh3.New() + }, []byte(key)) mac.Write(buf[:4]) - ipBytes, _ := ip.MarshalBinary() - mac.Write(ipBytes) + mac.Write(ip.AsSlice()) macBytes := mac.Sum(nil)[:4] copy(buf[4:], macBytes) @@ -93,7 +95,7 @@ func TestReuseGeneratorGenerate(t *testing.T) { cid := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key) require.Len(t, cid, 8) - gen := NewConnectionIDGenerator(tt.key) + gen := NewConnectionIDGenerator(tt.key, 0) for i := 0; i < 3; i++ { connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0)) @@ -106,10 +108,10 @@ func TestReuseGeneratorGenerate(t *testing.T) { func TestReuseGeneratorValidate(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) { - gen := NewConnectionIDGenerator(tt.key) + gen := NewConnectionIDGenerator(tt.key, time.Minute) cid := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0)) for i := 0; i < 3; i++ { - got := gen.Validate(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0), time.Minute) + got := gen.Validate(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0)) if got != tt.valid { t.Errorf("expected validity: %t got validity: %t", tt.valid, got) } @@ -159,7 +161,7 @@ func BenchmarkConnectionIDGenerator_Generate(b *testing.B) { pool := &sync.Pool{ New: func() any { - return NewConnectionIDGenerator(key) + return NewConnectionIDGenerator(key, 0) }, } @@ -197,14 +199,14 @@ func BenchmarkConnectionIDGenerator_Validate(b *testing.B) { pool := &sync.Pool{ New: func() any { - return NewConnectionIDGenerator(key) + return NewConnectionIDGenerator(key, 10*time.Second) }, } b.RunParallel(func(pb *testing.PB) { for pb.Next() { gen := pool.Get().(*ConnectionIDGenerator) - if !gen.Validate(cid, ip, createdAt, 10*time.Second) { + if !gen.Validate(cid, ip, createdAt) { b.FailNow() } pool.Put(gen) diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index 61c1a3e..7c8a22b 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -78,7 +78,6 @@ type udpFE struct { wg sync.WaitGroup genPool *sync.Pool logic *middleware.Logic - maxClockSkew time.Duration collectTimings bool ctxCancel context.CancelFunc onceCloser sync.Once @@ -98,12 +97,11 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, sockets: make([]*net.UDPConn, cfg.Workers), closing: make(chan any), logic: logic, - maxClockSkew: cfg.MaxClockSkew, collectTimings: cfg.EnableRequestTiming, ParseOptions: cfg.ParseOptions, genPool: &sync.Pool{ New: func() any { - return NewConnectionIDGenerator(cfg.PrivateKey) + return NewConnectionIDGenerator(cfg.PrivateKey, cfg.MaxClockSkew) }, }, } @@ -241,7 +239,7 @@ func (f *udpFE) handleRequest(ctx context.Context, r Request, w ResponseWriter) // If this isn't requesting a new connection ID and the connection ID is // invalid, then fail. - if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now(), f.maxClockSkew) { + if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now()) { err = errBadConnectionID WriteError(w, txID, err) return diff --git a/storage/memory/storage.go b/storage/memory/storage.go index d8d20a3..c6da56d 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -140,16 +140,6 @@ func (p *ihSwarm) keys(fn func(k bittorrent.InfoHash) bool) { p.RUnlock() } -func (p *ihSwarm) forEach(fn func(k bittorrent.InfoHash, v swarm) bool) { - p.RLock() - for k, v := range p.m { - if !fn(k, v) { - break - } - } - p.RUnlock() -} - type swarm struct { // map serialized peer to mtime seeders *peers @@ -172,7 +162,6 @@ func (p *peers) set(k bittorrent.Peer, v int64) { p.Lock() p.m[k] = v p.Unlock() - return } func (p *peers) del(k bittorrent.Peer) (ok bool) {