mirror of
https://github.com/sot-tech/mochi.git
synced 2026-07-03 07:13:59 -07:00
move PRNG to pkg, fix lint warnings
This commit is contained in:
@@ -3,18 +3,31 @@ package udp
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"encoding/binary"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"hash"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
|
||||
"github.com/sot-tech/mochi/pkg/log"
|
||||
"github.com/sot-tech/mochi/pkg/xorshift"
|
||||
)
|
||||
|
||||
// ttl is the duration a connection ID should be valid according to BEP 15.
|
||||
var ttl = int64(2 * time.Minute)
|
||||
|
||||
const (
|
||||
// length of connection ID
|
||||
connIDLen = 8
|
||||
// uint64 length + 1 byte salt
|
||||
buffLen = 9
|
||||
// 16 bytes enough for hashes with output length up to 128bit
|
||||
scratchLen = 16
|
||||
// length of HMAC in bytes to place it in connection ID
|
||||
hmacLen = 5
|
||||
)
|
||||
|
||||
// A ConnectionIDGenerator is a reusable generator and validator for connection
|
||||
// IDs as described in BEP 15.
|
||||
// It is not thread safe, but is safe to be pooled and reused by other
|
||||
@@ -33,16 +46,18 @@ type ConnectionIDGenerator struct {
|
||||
// It will be overwritten by subsequent calls to Generate.
|
||||
connID []byte
|
||||
|
||||
// buffer for HMAC input
|
||||
buff []byte
|
||||
|
||||
// scratch is a 32-byte slice that is used as a scratchpad for the generated
|
||||
// HMACs.
|
||||
// HMACs to increase hash performance.
|
||||
scratch []byte
|
||||
|
||||
// the leeway for a timestamp on a connection ID.
|
||||
maxClockSkew int64
|
||||
|
||||
cnt uint64
|
||||
// PRNG footprint holder
|
||||
s uint64
|
||||
}
|
||||
|
||||
// NewConnectionIDGenerator creates a new connection ID generator.
|
||||
@@ -51,30 +66,21 @@ func NewConnectionIDGenerator(key string, maxClockSkew time.Duration) *Connectio
|
||||
mac: hmac.New(func() hash.Hash {
|
||||
return xxhash.New()
|
||||
}, []byte(key)),
|
||||
connID: make([]byte, 8),
|
||||
buff: make([]byte, 9),
|
||||
scratch: make([]byte, 16),
|
||||
connID: make([]byte, connIDLen),
|
||||
buff: make([]byte, buffLen),
|
||||
scratch: make([]byte, scratchLen),
|
||||
maxClockSkew: int64(maxClockSkew),
|
||||
cnt: rand.Uint64(),
|
||||
s: rand.Uint64(),
|
||||
}
|
||||
}
|
||||
|
||||
// xor-shift-star generator
|
||||
func (g *ConnectionIDGenerator) nextCnt() uint64 {
|
||||
g.cnt ^= g.cnt >> 12
|
||||
g.cnt ^= g.cnt << 25
|
||||
g.cnt ^= g.cnt >> 27
|
||||
g.cnt *= 0x2545F4914F6CDD1D
|
||||
return g.cnt
|
||||
}
|
||||
|
||||
// reset resets the generator.
|
||||
// This is called by other methods of the generator, it's not necessary to call
|
||||
// it after getting a generator from a pool.
|
||||
func (g *ConnectionIDGenerator) reset() {
|
||||
g.mac.Reset()
|
||||
g.connID = g.connID[:8]
|
||||
g.buff = g.buff[:9]
|
||||
g.connID = g.connID[:connIDLen]
|
||||
g.buff = g.buff[:buffLen]
|
||||
g.scratch = g.scratch[:0]
|
||||
}
|
||||
|
||||
@@ -84,6 +90,7 @@ func (g *ConnectionIDGenerator) reset() {
|
||||
// The first byte is random salt, next 2 bytes - truncated unix timestamp
|
||||
// when ID was generated, last 5 bytes are a truncated HMAC token created
|
||||
// from salt (1 byte), full unix timestamp (8 bytes) and source IP (4/16 bytes).
|
||||
//
|
||||
// Salt used to mitigate generation same MAC if there are several clients
|
||||
// from same IP sent requests within one second.
|
||||
//
|
||||
@@ -96,20 +103,22 @@ func (g *ConnectionIDGenerator) reset() {
|
||||
// to a pool and will be overwritten be subsequent calls to Generate!
|
||||
func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) (out []byte) {
|
||||
g.reset()
|
||||
g.buff[0] = byte(g.nextCnt())
|
||||
var r uint64
|
||||
r, g.s = xorshift.XorShift64S(g.s)
|
||||
g.buff[0] = byte(r)
|
||||
binary.BigEndian.PutUint64(g.buff[1:], uint64(now.Unix()))
|
||||
g.mac.Write(g.buff)
|
||||
g.mac.Write(ip.AsSlice())
|
||||
|
||||
g.scratch = g.mac.Sum(g.scratch)
|
||||
g.connID[0], g.connID[1], g.connID[2] = g.buff[0], g.buff[7], g.buff[8]
|
||||
copy(g.connID[3:], g.scratch[:5])
|
||||
copy(g.connID[connIDLen-hmacLen:], g.scratch[:hmacLen])
|
||||
|
||||
log.Debug().
|
||||
Stringer("ip", ip).
|
||||
Hex("connID", g.connID).
|
||||
Msg("generated connection ID")
|
||||
return g.connID[:8]
|
||||
return g.connID[:connIDLen]
|
||||
}
|
||||
|
||||
// Validate validates the given connection ID for an IP and the current time.
|
||||
@@ -117,13 +126,17 @@ func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now
|
||||
g.reset()
|
||||
nowTS := now.Unix()
|
||||
g.buff[0] = connectionID[0]
|
||||
// connectionID contains only 2 bytes of timestamp, so we clean little 16 bits to place it and rehash.
|
||||
// We will provide restored full timestamp respectively to current timestamp,
|
||||
// 2 bytes should be enough to avoid collisions within ~18 hours from same IP.
|
||||
ts := nowTS&((^int64(0)>>16)<<16) | int64(connectionID[1])<<8 | int64(connectionID[2])
|
||||
binary.BigEndian.PutUint64(g.buff[1:], uint64(ts))
|
||||
g.mac.Write(g.buff)
|
||||
g.mac.Write(ip.AsSlice())
|
||||
g.scratch = g.mac.Sum(g.scratch)
|
||||
res := hmac.Equal(g.scratch[:5], connectionID[3:8])
|
||||
res = ts-g.maxClockSkew <= nowTS && res
|
||||
res := hmac.Equal(g.scratch[:hmacLen], connectionID[connIDLen-hmacLen:connIDLen])
|
||||
// ts-skew < now < ts+ttl+skew
|
||||
res = ts-g.maxClockSkew < nowTS && res
|
||||
res = nowTS < ts+ttl+g.maxClockSkew && res
|
||||
log.Debug().
|
||||
Stringer("ip", ip).
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"crypto/hmac"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"hash"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/sot-tech/mochi/pkg/log"
|
||||
_ "github.com/sot-tech/mochi/pkg/randseed"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -55,7 +55,7 @@ func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
|
||||
binary.BigEndian.PutUint64(buffer[1:], uint64(now.Unix()))
|
||||
mac.Write(buffer)
|
||||
mac.Write(ip.AsSlice())
|
||||
buffer[0], buffer[1], buffer[2] = buffer[0], buffer[7], buffer[8]
|
||||
buffer[1], buffer[2] = buffer[7], buffer[8]
|
||||
copy(buffer[3:8], mac.Sum(nil))
|
||||
buffer = buffer[:8]
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestGeneration(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
|
||||
want := simpleNewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||
got := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||
require.Equal(t, want, got)
|
||||
require.NotEqual(t, want, got) // IDs should NOT be equal because of salt
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestReuseGeneratorGenerate(t *testing.T) {
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
|
||||
require.Equal(t, cid, connID)
|
||||
require.Equal(t, cid, connID) // IDs should NOT be equal because of salt
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/middleware"
|
||||
"github.com/sot-tech/mochi/pkg/conf"
|
||||
"github.com/sot-tech/mochi/pkg/xorshift"
|
||||
"github.com/sot-tech/mochi/storage"
|
||||
)
|
||||
|
||||
@@ -81,10 +82,10 @@ type hook struct {
|
||||
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
// Generate a probability p < 1.0.
|
||||
p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req))
|
||||
p, s0, s1 := xorshift.XoRoShiRo128SS(deriveEntropyFromRequest(req))
|
||||
if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability {
|
||||
// Generate the increase delta.
|
||||
v, _, _ := xoroshiro128p(s0, s1)
|
||||
v, _, _ := xorshift.XoRoShiRo128SS(s0, s1)
|
||||
add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second
|
||||
|
||||
resp.Interval += add
|
||||
@@ -113,14 +114,3 @@ func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 ui
|
||||
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
|
||||
return
|
||||
}
|
||||
|
||||
// xoroshiro128p calculates predictable pseudorandom number
|
||||
// with XOR/rotate/shift/rotate 128+ algorithm.
|
||||
// see https://prng.di.unimi.it/xoroshiro128plus.c
|
||||
func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) {
|
||||
result = s0 + s1
|
||||
s1 ^= s0
|
||||
ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
|
||||
ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package varinterval
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -64,13 +63,3 @@ func TestHandleAnnounce(t *testing.T) {
|
||||
require.True(t, resp.Interval > 0, "interval should have been increased")
|
||||
require.True(t, resp.MinInterval > 0, "min_interval should have been increased")
|
||||
}
|
||||
|
||||
func BenchmarkXORoShiRo128Plus(b *testing.B) {
|
||||
s0, s1 := rand.Uint64(), rand.Uint64()
|
||||
var v uint64
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s0, s1 = xoroshiro128p(s0, s1)
|
||||
}
|
||||
_, _, _ = v, s0, s1
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
// Package xorshift contains functions for fast generating
|
||||
// predictable pseudorandom numbers
|
||||
// See https://prng.di.unimi.it .
|
||||
package xorshift
|
||||
|
||||
// XoRoShiRo128SS calculates predictable pseudorandom number
|
||||
// with XOR/rotate/shift/rotate 128** (xoroshiro128starstar) algorithm.
|
||||
// In some cases a little faster than XorShift64S, but uses 128 bits footprint.
|
||||
// see https://prng.di.unimi.it/xoroshiro128starstar.c
|
||||
func XoRoShiRo128SS(s0, s1 uint64) (uint64, uint64, uint64) {
|
||||
r := s0 * 5
|
||||
r = ((r << 7) | (r >> 57)) * 9 // rotl(s0*5, 7) * 9
|
||||
s1 ^= s0
|
||||
s0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
|
||||
s1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
|
||||
return r, s0, s1
|
||||
}
|
||||
|
||||
// XorShift64S calculates predictable pseudorandom number
|
||||
// with XOR/Shift 64* (shorshift64*) algorithm.
|
||||
// see https://vigna.di.unimi.it/ftp/papers/xorshift.pdf
|
||||
func XorShift64S(s uint64) (uint64, uint64) {
|
||||
s ^= s >> 12
|
||||
s ^= s << 25
|
||||
s ^= s >> 27
|
||||
return s * uint64(0x2545F4914F6CDD1D), s
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package xorshift
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkRand(b *testing.B) {
|
||||
var cnt uint64
|
||||
for i := 0; i < b.N; i++ {
|
||||
cnt = rand.Uint64()
|
||||
}
|
||||
_ = cnt
|
||||
}
|
||||
|
||||
func BenchmarkXoRoShiRo128SS(b *testing.B) {
|
||||
v, s0, s1 := uint64(0), rand.Uint64(), rand.Uint64()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s0, s1 = XoRoShiRo128SS(s0, s1)
|
||||
}
|
||||
_, _, _ = v, s0, s1
|
||||
}
|
||||
|
||||
func BenchmarkXorShift64Star(b *testing.B) {
|
||||
v, s := uint64(0), rand.Uint64()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s = XorShift64S(s)
|
||||
}
|
||||
_, _ = v, s
|
||||
}
|
||||
Reference in New Issue
Block a user