From 3b1c713c85958932ac34723a56d0bc6a9c6c313c Mon Sep 17 00:00:00 2001 From: "Lawrence, Rendall" Date: Wed, 30 Nov 2022 23:20:05 +0300 Subject: [PATCH] move PRNG to pkg, fix lint warnings --- frontend/udp/connection_id.go | 59 +++++++++++++--------- frontend/udp/connection_id_test.go | 8 +-- middleware/varinterval/varinterval.go | 16 ++---- middleware/varinterval/varinterval_test.go | 11 ---- pkg/xorshift/prng.go | 27 ++++++++++ pkg/xorshift/prng_test.go | 30 +++++++++++ 6 files changed, 100 insertions(+), 51 deletions(-) create mode 100644 pkg/xorshift/prng.go create mode 100644 pkg/xorshift/prng_test.go diff --git a/frontend/udp/connection_id.go b/frontend/udp/connection_id.go index cbf5a9a..006dbda 100644 --- a/frontend/udp/connection_id.go +++ b/frontend/udp/connection_id.go @@ -3,18 +3,31 @@ package udp import ( "crypto/hmac" "encoding/binary" - "github.com/cespare/xxhash/v2" "hash" "math/rand" "net/netip" "time" + "github.com/cespare/xxhash/v2" + "github.com/sot-tech/mochi/pkg/log" + "github.com/sot-tech/mochi/pkg/xorshift" ) // ttl is the duration a connection ID should be valid according to BEP 15. var ttl = int64(2 * time.Minute) +const ( + // length of connection ID + connIDLen = 8 + // uint64 length + 1 byte salt + buffLen = 9 + // 16 bytes enough for hashes with output length up to 128bit + scratchLen = 16 + // length of HMAC in bytes to place it in connection ID + hmacLen = 5 +) + // A ConnectionIDGenerator is a reusable generator and validator for connection // IDs as described in BEP 15. // It is not thread safe, but is safe to be pooled and reused by other @@ -33,16 +46,18 @@ type ConnectionIDGenerator struct { // It will be overwritten by subsequent calls to Generate. connID []byte + // buffer for HMAC input buff []byte // scratch is a 32-byte slice that is used as a scratchpad for the generated - // HMACs. + // HMACs to increase hash performance. scratch []byte // the leeway for a timestamp on a connection ID. maxClockSkew int64 - cnt uint64 + // PRNG footprint holder + s uint64 } // NewConnectionIDGenerator creates a new connection ID generator. @@ -51,30 +66,21 @@ func NewConnectionIDGenerator(key string, maxClockSkew time.Duration) *Connectio mac: hmac.New(func() hash.Hash { return xxhash.New() }, []byte(key)), - connID: make([]byte, 8), - buff: make([]byte, 9), - scratch: make([]byte, 16), + connID: make([]byte, connIDLen), + buff: make([]byte, buffLen), + scratch: make([]byte, scratchLen), maxClockSkew: int64(maxClockSkew), - cnt: rand.Uint64(), + s: rand.Uint64(), } } -// xor-shift-star generator -func (g *ConnectionIDGenerator) nextCnt() uint64 { - g.cnt ^= g.cnt >> 12 - g.cnt ^= g.cnt << 25 - g.cnt ^= g.cnt >> 27 - g.cnt *= 0x2545F4914F6CDD1D - return g.cnt -} - // reset resets the generator. // This is called by other methods of the generator, it's not necessary to call // it after getting a generator from a pool. func (g *ConnectionIDGenerator) reset() { g.mac.Reset() - g.connID = g.connID[:8] - g.buff = g.buff[:9] + g.connID = g.connID[:connIDLen] + g.buff = g.buff[:buffLen] g.scratch = g.scratch[:0] } @@ -84,6 +90,7 @@ func (g *ConnectionIDGenerator) reset() { // The first byte is random salt, next 2 bytes - truncated unix timestamp // when ID was generated, last 5 bytes are a truncated HMAC token created // from salt (1 byte), full unix timestamp (8 bytes) and source IP (4/16 bytes). +// // Salt used to mitigate generation same MAC if there are several clients // from same IP sent requests within one second. // @@ -96,20 +103,22 @@ func (g *ConnectionIDGenerator) reset() { // to a pool and will be overwritten be subsequent calls to Generate! func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) (out []byte) { g.reset() - g.buff[0] = byte(g.nextCnt()) + var r uint64 + r, g.s = xorshift.XorShift64S(g.s) + g.buff[0] = byte(r) binary.BigEndian.PutUint64(g.buff[1:], uint64(now.Unix())) g.mac.Write(g.buff) g.mac.Write(ip.AsSlice()) g.scratch = g.mac.Sum(g.scratch) g.connID[0], g.connID[1], g.connID[2] = g.buff[0], g.buff[7], g.buff[8] - copy(g.connID[3:], g.scratch[:5]) + copy(g.connID[connIDLen-hmacLen:], g.scratch[:hmacLen]) log.Debug(). Stringer("ip", ip). Hex("connID", g.connID). Msg("generated connection ID") - return g.connID[:8] + return g.connID[:connIDLen] } // Validate validates the given connection ID for an IP and the current time. @@ -117,13 +126,17 @@ func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now g.reset() nowTS := now.Unix() g.buff[0] = connectionID[0] + // connectionID contains only 2 bytes of timestamp, so we clean little 16 bits to place it and rehash. + // We will provide restored full timestamp respectively to current timestamp, + // 2 bytes should be enough to avoid collisions within ~18 hours from same IP. ts := nowTS&((^int64(0)>>16)<<16) | int64(connectionID[1])<<8 | int64(connectionID[2]) binary.BigEndian.PutUint64(g.buff[1:], uint64(ts)) g.mac.Write(g.buff) g.mac.Write(ip.AsSlice()) g.scratch = g.mac.Sum(g.scratch) - res := hmac.Equal(g.scratch[:5], connectionID[3:8]) - res = ts-g.maxClockSkew <= nowTS && res + res := hmac.Equal(g.scratch[:hmacLen], connectionID[connIDLen-hmacLen:connIDLen]) + // ts-skew < now < ts+ttl+skew + res = ts-g.maxClockSkew < nowTS && res res = nowTS < ts+ttl+g.maxClockSkew && res log.Debug(). Stringer("ip", ip). diff --git a/frontend/udp/connection_id_test.go b/frontend/udp/connection_id_test.go index 139bf12..fefeacb 100644 --- a/frontend/udp/connection_id_test.go +++ b/frontend/udp/connection_id_test.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "encoding/binary" "fmt" - "github.com/cespare/xxhash/v2" "hash" "math/rand" "net/netip" @@ -12,6 +11,7 @@ import ( "testing" "time" + "github.com/cespare/xxhash/v2" "github.com/sot-tech/mochi/pkg/log" _ "github.com/sot-tech/mochi/pkg/randseed" "github.com/stretchr/testify/require" @@ -55,7 +55,7 @@ func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte { binary.BigEndian.PutUint64(buffer[1:], uint64(now.Unix())) mac.Write(buffer) mac.Write(ip.AsSlice()) - buffer[0], buffer[1], buffer[2] = buffer[0], buffer[7], buffer[8] + buffer[1], buffer[2] = buffer[7], buffer[8] copy(buffer[3:8], mac.Sum(nil)) buffer = buffer[:8] @@ -86,7 +86,7 @@ func TestGeneration(t *testing.T) { t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) { want := simpleNewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key) got := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key) - require.Equal(t, want, got) + require.NotEqual(t, want, got) // IDs should NOT be equal because of salt }) } } @@ -101,7 +101,7 @@ func TestReuseGeneratorGenerate(t *testing.T) { for i := 0; i < 3; i++ { connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0)) - require.Equal(t, cid, connID) + require.Equal(t, cid, connID) // IDs should NOT be equal because of salt } }) } diff --git a/middleware/varinterval/varinterval.go b/middleware/varinterval/varinterval.go index 0587d3e..02b4f7e 100644 --- a/middleware/varinterval/varinterval.go +++ b/middleware/varinterval/varinterval.go @@ -13,6 +13,7 @@ import ( "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/conf" + "github.com/sot-tech/mochi/pkg/xorshift" "github.com/sot-tech/mochi/storage" ) @@ -81,10 +82,10 @@ type hook struct { func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) { // Generate a probability p < 1.0. - p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req)) + p, s0, s1 := xorshift.XoRoShiRo128SS(deriveEntropyFromRequest(req)) if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability { // Generate the increase delta. - v, _, _ := xoroshiro128p(s0, s1) + v, _, _ := xorshift.XoRoShiRo128SS(s0, s1) add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second resp.Interval += add @@ -113,14 +114,3 @@ func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 ui v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16]) return } - -// xoroshiro128p calculates predictable pseudorandom number -// with XOR/rotate/shift/rotate 128+ algorithm. -// see https://prng.di.unimi.it/xoroshiro128plus.c -func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) { - result = s0 + s1 - s1 ^= s0 - ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16) - ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37) - return -} diff --git a/middleware/varinterval/varinterval_test.go b/middleware/varinterval/varinterval_test.go index 8ff8ee6..c792485 100644 --- a/middleware/varinterval/varinterval_test.go +++ b/middleware/varinterval/varinterval_test.go @@ -3,7 +3,6 @@ package varinterval import ( "context" "fmt" - "math/rand" "testing" "github.com/stretchr/testify/require" @@ -64,13 +63,3 @@ func TestHandleAnnounce(t *testing.T) { require.True(t, resp.Interval > 0, "interval should have been increased") require.True(t, resp.MinInterval > 0, "min_interval should have been increased") } - -func BenchmarkXORoShiRo128Plus(b *testing.B) { - s0, s1 := rand.Uint64(), rand.Uint64() - var v uint64 - b.ResetTimer() - for i := 0; i < b.N; i++ { - v, s0, s1 = xoroshiro128p(s0, s1) - } - _, _, _ = v, s0, s1 -} diff --git a/pkg/xorshift/prng.go b/pkg/xorshift/prng.go new file mode 100644 index 0000000..e153517 --- /dev/null +++ b/pkg/xorshift/prng.go @@ -0,0 +1,27 @@ +// Package xorshift contains functions for fast generating +// predictable pseudorandom numbers +// See https://prng.di.unimi.it . +package xorshift + +// XoRoShiRo128SS calculates predictable pseudorandom number +// with XOR/rotate/shift/rotate 128** (xoroshiro128starstar) algorithm. +// In some cases a little faster than XorShift64S, but uses 128 bits footprint. +// see https://prng.di.unimi.it/xoroshiro128starstar.c +func XoRoShiRo128SS(s0, s1 uint64) (uint64, uint64, uint64) { + r := s0 * 5 + r = ((r << 7) | (r >> 57)) * 9 // rotl(s0*5, 7) * 9 + s1 ^= s0 + s0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16) + s1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37) + return r, s0, s1 +} + +// XorShift64S calculates predictable pseudorandom number +// with XOR/Shift 64* (shorshift64*) algorithm. +// see https://vigna.di.unimi.it/ftp/papers/xorshift.pdf +func XorShift64S(s uint64) (uint64, uint64) { + s ^= s >> 12 + s ^= s << 25 + s ^= s >> 27 + return s * uint64(0x2545F4914F6CDD1D), s +} diff --git a/pkg/xorshift/prng_test.go b/pkg/xorshift/prng_test.go new file mode 100644 index 0000000..ce93c0b --- /dev/null +++ b/pkg/xorshift/prng_test.go @@ -0,0 +1,30 @@ +package xorshift + +import ( + "math/rand" + "testing" +) + +func BenchmarkRand(b *testing.B) { + var cnt uint64 + for i := 0; i < b.N; i++ { + cnt = rand.Uint64() + } + _ = cnt +} + +func BenchmarkXoRoShiRo128SS(b *testing.B) { + v, s0, s1 := uint64(0), rand.Uint64(), rand.Uint64() + for i := 0; i < b.N; i++ { + v, s0, s1 = XoRoShiRo128SS(s0, s1) + } + _, _, _ = v, s0, s1 +} + +func BenchmarkXorShift64Star(b *testing.B) { + v, s := uint64(0), rand.Uint64() + for i := 0; i < b.N; i++ { + v, s = XorShift64S(s) + } + _, _ = v, s +}