move PRNG to pkg, fix lint warnings

This commit is contained in:
Lawrence, Rendall
2022-11-30 23:20:05 +03:00
parent f80e9af156
commit 3b1c713c85
6 changed files with 100 additions and 51 deletions
+36 -23
View File
@@ -3,18 +3,31 @@ package udp
import (
"crypto/hmac"
"encoding/binary"
"github.com/cespare/xxhash/v2"
"hash"
"math/rand"
"net/netip"
"time"
"github.com/cespare/xxhash/v2"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/xorshift"
)
// ttl is the duration a connection ID should be valid according to BEP 15.
var ttl = int64(2 * time.Minute)
const (
// length of connection ID
connIDLen = 8
// uint64 length + 1 byte salt
buffLen = 9
// 16 bytes enough for hashes with output length up to 128bit
scratchLen = 16
// length of HMAC in bytes to place it in connection ID
hmacLen = 5
)
// A ConnectionIDGenerator is a reusable generator and validator for connection
// IDs as described in BEP 15.
// It is not thread safe, but is safe to be pooled and reused by other
@@ -33,16 +46,18 @@ type ConnectionIDGenerator struct {
// It will be overwritten by subsequent calls to Generate.
connID []byte
// buffer for HMAC input
buff []byte
// scratch is a 32-byte slice that is used as a scratchpad for the generated
// HMACs.
// HMACs to increase hash performance.
scratch []byte
// the leeway for a timestamp on a connection ID.
maxClockSkew int64
cnt uint64
// PRNG footprint holder
s uint64
}
// NewConnectionIDGenerator creates a new connection ID generator.
@@ -51,30 +66,21 @@ func NewConnectionIDGenerator(key string, maxClockSkew time.Duration) *Connectio
mac: hmac.New(func() hash.Hash {
return xxhash.New()
}, []byte(key)),
connID: make([]byte, 8),
buff: make([]byte, 9),
scratch: make([]byte, 16),
connID: make([]byte, connIDLen),
buff: make([]byte, buffLen),
scratch: make([]byte, scratchLen),
maxClockSkew: int64(maxClockSkew),
cnt: rand.Uint64(),
s: rand.Uint64(),
}
}
// xor-shift-star generator
func (g *ConnectionIDGenerator) nextCnt() uint64 {
g.cnt ^= g.cnt >> 12
g.cnt ^= g.cnt << 25
g.cnt ^= g.cnt >> 27
g.cnt *= 0x2545F4914F6CDD1D
return g.cnt
}
// reset resets the generator.
// This is called by other methods of the generator, it's not necessary to call
// it after getting a generator from a pool.
func (g *ConnectionIDGenerator) reset() {
g.mac.Reset()
g.connID = g.connID[:8]
g.buff = g.buff[:9]
g.connID = g.connID[:connIDLen]
g.buff = g.buff[:buffLen]
g.scratch = g.scratch[:0]
}
@@ -84,6 +90,7 @@ func (g *ConnectionIDGenerator) reset() {
// The first byte is random salt, next 2 bytes - truncated unix timestamp
// when ID was generated, last 5 bytes are a truncated HMAC token created
// from salt (1 byte), full unix timestamp (8 bytes) and source IP (4/16 bytes).
//
// Salt used to mitigate generation same MAC if there are several clients
// from same IP sent requests within one second.
//
@@ -96,20 +103,22 @@ func (g *ConnectionIDGenerator) reset() {
// to a pool and will be overwritten be subsequent calls to Generate!
func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) (out []byte) {
g.reset()
g.buff[0] = byte(g.nextCnt())
var r uint64
r, g.s = xorshift.XorShift64S(g.s)
g.buff[0] = byte(r)
binary.BigEndian.PutUint64(g.buff[1:], uint64(now.Unix()))
g.mac.Write(g.buff)
g.mac.Write(ip.AsSlice())
g.scratch = g.mac.Sum(g.scratch)
g.connID[0], g.connID[1], g.connID[2] = g.buff[0], g.buff[7], g.buff[8]
copy(g.connID[3:], g.scratch[:5])
copy(g.connID[connIDLen-hmacLen:], g.scratch[:hmacLen])
log.Debug().
Stringer("ip", ip).
Hex("connID", g.connID).
Msg("generated connection ID")
return g.connID[:8]
return g.connID[:connIDLen]
}
// Validate validates the given connection ID for an IP and the current time.
@@ -117,13 +126,17 @@ func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now
g.reset()
nowTS := now.Unix()
g.buff[0] = connectionID[0]
// connectionID contains only 2 bytes of timestamp, so we clean little 16 bits to place it and rehash.
// We will provide restored full timestamp respectively to current timestamp,
// 2 bytes should be enough to avoid collisions within ~18 hours from same IP.
ts := nowTS&((^int64(0)>>16)<<16) | int64(connectionID[1])<<8 | int64(connectionID[2])
binary.BigEndian.PutUint64(g.buff[1:], uint64(ts))
g.mac.Write(g.buff)
g.mac.Write(ip.AsSlice())
g.scratch = g.mac.Sum(g.scratch)
res := hmac.Equal(g.scratch[:5], connectionID[3:8])
res = ts-g.maxClockSkew <= nowTS && res
res := hmac.Equal(g.scratch[:hmacLen], connectionID[connIDLen-hmacLen:connIDLen])
// ts-skew < now < ts+ttl+skew
res = ts-g.maxClockSkew < nowTS && res
res = nowTS < ts+ttl+g.maxClockSkew && res
log.Debug().
Stringer("ip", ip).
+4 -4
View File
@@ -4,7 +4,6 @@ import (
"crypto/hmac"
"encoding/binary"
"fmt"
"github.com/cespare/xxhash/v2"
"hash"
"math/rand"
"net/netip"
@@ -12,6 +11,7 @@ import (
"testing"
"time"
"github.com/cespare/xxhash/v2"
"github.com/sot-tech/mochi/pkg/log"
_ "github.com/sot-tech/mochi/pkg/randseed"
"github.com/stretchr/testify/require"
@@ -55,7 +55,7 @@ func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
binary.BigEndian.PutUint64(buffer[1:], uint64(now.Unix()))
mac.Write(buffer)
mac.Write(ip.AsSlice())
buffer[0], buffer[1], buffer[2] = buffer[0], buffer[7], buffer[8]
buffer[1], buffer[2] = buffer[7], buffer[8]
copy(buffer[3:8], mac.Sum(nil))
buffer = buffer[:8]
@@ -86,7 +86,7 @@ func TestGeneration(t *testing.T) {
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
want := simpleNewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
got := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
require.Equal(t, want, got)
require.NotEqual(t, want, got) // IDs should NOT be equal because of salt
})
}
}
@@ -101,7 +101,7 @@ func TestReuseGeneratorGenerate(t *testing.T) {
for i := 0; i < 3; i++ {
connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
require.Equal(t, cid, connID)
require.Equal(t, cid, connID) // IDs should NOT be equal because of salt
}
})
}
+3 -13
View File
@@ -13,6 +13,7 @@ import (
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/xorshift"
"github.com/sot-tech/mochi/storage"
)
@@ -81,10 +82,10 @@ type hook struct {
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
// Generate a probability p < 1.0.
p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req))
p, s0, s1 := xorshift.XoRoShiRo128SS(deriveEntropyFromRequest(req))
if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability {
// Generate the increase delta.
v, _, _ := xoroshiro128p(s0, s1)
v, _, _ := xorshift.XoRoShiRo128SS(s0, s1)
add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second
resp.Interval += add
@@ -113,14 +114,3 @@ func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 ui
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
return
}
// xoroshiro128p calculates predictable pseudorandom number
// with XOR/rotate/shift/rotate 128+ algorithm.
// see https://prng.di.unimi.it/xoroshiro128plus.c
func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) {
result = s0 + s1
s1 ^= s0
ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
return
}
@@ -3,7 +3,6 @@ package varinterval
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/stretchr/testify/require"
@@ -64,13 +63,3 @@ func TestHandleAnnounce(t *testing.T) {
require.True(t, resp.Interval > 0, "interval should have been increased")
require.True(t, resp.MinInterval > 0, "min_interval should have been increased")
}
func BenchmarkXORoShiRo128Plus(b *testing.B) {
s0, s1 := rand.Uint64(), rand.Uint64()
var v uint64
b.ResetTimer()
for i := 0; i < b.N; i++ {
v, s0, s1 = xoroshiro128p(s0, s1)
}
_, _, _ = v, s0, s1
}
+27
View File
@@ -0,0 +1,27 @@
// Package xorshift contains functions for fast generating
// predictable pseudorandom numbers
// See https://prng.di.unimi.it .
package xorshift
// XoRoShiRo128SS calculates predictable pseudorandom number
// with XOR/rotate/shift/rotate 128** (xoroshiro128starstar) algorithm.
// In some cases a little faster than XorShift64S, but uses 128 bits footprint.
// see https://prng.di.unimi.it/xoroshiro128starstar.c
func XoRoShiRo128SS(s0, s1 uint64) (uint64, uint64, uint64) {
r := s0 * 5
r = ((r << 7) | (r >> 57)) * 9 // rotl(s0*5, 7) * 9
s1 ^= s0
s0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
s1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
return r, s0, s1
}
// XorShift64S calculates predictable pseudorandom number
// with XOR/Shift 64* (shorshift64*) algorithm.
// see https://vigna.di.unimi.it/ftp/papers/xorshift.pdf
func XorShift64S(s uint64) (uint64, uint64) {
s ^= s >> 12
s ^= s << 25
s ^= s >> 27
return s * uint64(0x2545F4914F6CDD1D), s
}
+30
View File
@@ -0,0 +1,30 @@
package xorshift
import (
"math/rand"
"testing"
)
func BenchmarkRand(b *testing.B) {
var cnt uint64
for i := 0; i < b.N; i++ {
cnt = rand.Uint64()
}
_ = cnt
}
func BenchmarkXoRoShiRo128SS(b *testing.B) {
v, s0, s1 := uint64(0), rand.Uint64(), rand.Uint64()
for i := 0; i < b.N; i++ {
v, s0, s1 = XoRoShiRo128SS(s0, s1)
}
_, _, _ = v, s0, s1
}
func BenchmarkXorShift64Star(b *testing.B) {
v, s := uint64(0), rand.Uint64()
for i := 0; i < b.N; i++ {
v, s = XorShift64S(s)
}
_, _ = v, s
}