mirror of
https://github.com/sot-tech/mochi.git
synced 2026-06-11 07:23:30 -07:00
(partially tested) add scrape JWT validation
* remove `xorshift` package, add internal function in `varinterval` * change `bittorrent.QueryParams` getters to search case insensitive keys
This commit is contained in:
@@ -180,15 +180,15 @@ func parseQuery(query string) (q *QueryParams, err error) {
|
||||
|
||||
// String returns a string parsed from a query. Every key can be returned as a
|
||||
// string because they are encoded in the URL as strings.
|
||||
func (qp *QueryParams) String(key string) (string, bool) {
|
||||
value, ok := qp.params[key]
|
||||
func (qp QueryParams) String(key string) (string, bool) {
|
||||
value, ok := qp.params[strings.ToLower(key)]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Uint returns an uint parsed from a query. After being called, it is safe to
|
||||
// cast the uint64 to your desired length.
|
||||
func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) {
|
||||
str, exists := qp.params[key]
|
||||
func (qp QueryParams) Uint(key string, bitSize int) (uint64, error) {
|
||||
str, exists := qp.params[strings.ToLower(key)]
|
||||
if !exists {
|
||||
return 0, ErrKeyNotFound
|
||||
}
|
||||
@@ -197,17 +197,17 @@ func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) {
|
||||
}
|
||||
|
||||
// InfoHashes returns a list of requested infohashes.
|
||||
func (qp *QueryParams) InfoHashes() []InfoHash {
|
||||
func (qp QueryParams) InfoHashes() []InfoHash {
|
||||
return qp.infoHashes
|
||||
}
|
||||
|
||||
// RawPath returns the raw path from the parsed URL.
|
||||
func (qp *QueryParams) RawPath() string {
|
||||
func (qp QueryParams) RawPath() string {
|
||||
return qp.path
|
||||
}
|
||||
|
||||
// RawQuery returns the raw query from the parsed URL.
|
||||
func (qp *QueryParams) RawQuery() string {
|
||||
func (qp QueryParams) RawQuery() string {
|
||||
return qp.query
|
||||
}
|
||||
|
||||
|
||||
Vendored
+3
@@ -264,10 +264,13 @@ mochi:
|
||||
prehooks:
|
||||
# - name: jwt
|
||||
# options:
|
||||
# header: "authorization"
|
||||
# issuer: "https://issuer.com"
|
||||
# audience: "https://some.issuer.com"
|
||||
# jwk_set_url: "https://issuer.com/keys"
|
||||
# jwk_set_update_interval: 5m
|
||||
# handle_announce: true
|
||||
# handle_scrape: false
|
||||
#
|
||||
# - name: client approval
|
||||
# options:
|
||||
|
||||
+177
-67
@@ -1,15 +1,15 @@
|
||||
// Package jwt implements a Hook that fails an Announce if the client's request
|
||||
// Package jwt implements a Hook that fails on Announce or Scrape if the client's request
|
||||
// is missing a valid JSON Web Token.
|
||||
//
|
||||
// JWTs are validated against the standard claims in RFC7519 along with an
|
||||
// extra "infohash" claim that verifies the client has access to the Swarm.
|
||||
// RS256 keys are asynchronously rotated from a provided JWK Set HTTP endpoint.
|
||||
// extra "infohash(es)" claim that verifies the client has access to the Swarm.
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/MicahParks/keyfunc"
|
||||
@@ -24,7 +24,11 @@ import (
|
||||
)
|
||||
|
||||
// Name is the name by which this middleware is registered with Conf.
|
||||
const Name = "jwt"
|
||||
const (
|
||||
Name = "jwt"
|
||||
authorizationHeader = "authorization"
|
||||
bearerAuthPrefix = "bearer "
|
||||
)
|
||||
|
||||
func init() {
|
||||
middleware.RegisterBuilder(Name, build)
|
||||
@@ -38,8 +42,6 @@ var (
|
||||
// ErrInvalidJWT is returned when a JWT fails to verify.
|
||||
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
|
||||
|
||||
errInvalidInfoHashClaim = errors.New("token has invalid \"infohash\" claim")
|
||||
|
||||
errJWKsNotSet = errors.New("required parameters not provided: Issuer, Audience and/or JWKSetURL")
|
||||
|
||||
hmacAlgorithms = jwt.WithValidMethods([]string{
|
||||
@@ -54,10 +56,13 @@ var (
|
||||
// Config represents all the values required by this middleware to fetch JWKs
|
||||
// and verify JWTs.
|
||||
type Config struct {
|
||||
Header string
|
||||
Issuer string
|
||||
Audience string
|
||||
JWKSetURL string `cfg:"jwk_set_url"`
|
||||
JWKUpdateInterval time.Duration `cfg:"jwk_set_update_interval"`
|
||||
HandleAnnounce bool `cfg:"handle_announce"`
|
||||
HandleScrape bool `cfg:"handle_scrape"`
|
||||
}
|
||||
|
||||
type hook struct {
|
||||
@@ -65,30 +70,37 @@ type hook struct {
|
||||
jwks *keyfunc.JWKS
|
||||
}
|
||||
|
||||
type claims struct {
|
||||
jwt.RegisteredClaims
|
||||
InfoHash string `json:"infohash,omitempty"`
|
||||
}
|
||||
|
||||
func build(options conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err error) {
|
||||
var cfg Config
|
||||
|
||||
logger.Debug().Object("options", options).Msg("creating new JWT middleware")
|
||||
|
||||
if err = options.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("middleware %s: %w", Name, err)
|
||||
}
|
||||
|
||||
logger.Debug().Object("options", options).Msg("creating new JWT middleware")
|
||||
if len(cfg.JWKSetURL) > 0 && len(cfg.Issuer) > 0 && len(cfg.Audience) > 0 {
|
||||
if len(cfg.Header) == 0 {
|
||||
cfg.Header = authorizationHeader
|
||||
logger.Warn().
|
||||
Str("name", "Header").
|
||||
Str("default", cfg.Header).
|
||||
Msg("falling back to default configuration")
|
||||
}
|
||||
|
||||
if len(cfg.JWKSetURL) > 0 {
|
||||
var jwks *keyfunc.JWKS
|
||||
jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{
|
||||
Ctx: context.Background(),
|
||||
RefreshErrorHandler: func(err error) {
|
||||
logger.Error().Err(err).Msg("error occurred while updating JWKs")
|
||||
},
|
||||
RefreshInterval: cfg.JWKUpdateInterval,
|
||||
RefreshUnknownKID: true,
|
||||
})
|
||||
if cfg.HandleAnnounce || cfg.HandleScrape {
|
||||
jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{
|
||||
Ctx: context.Background(),
|
||||
RefreshErrorHandler: func(err error) {
|
||||
logger.Error().Err(err).Msg("error occurred while updating JWKs")
|
||||
},
|
||||
RefreshInterval: cfg.JWKUpdateInterval,
|
||||
RefreshUnknownKID: true,
|
||||
})
|
||||
} else {
|
||||
logger.Warn().Msg("both announce and scrape handle disabled")
|
||||
}
|
||||
if err == nil {
|
||||
h = &hook{
|
||||
cfg: cfg,
|
||||
@@ -111,73 +123,171 @@ func (h *hook) Stop() stop.Result {
|
||||
return c.Result()
|
||||
}
|
||||
|
||||
type compatibleClaims interface {
|
||||
Valid() error
|
||||
ToRegisteredClaims() jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type announceClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
InfoHash string `json:"infohash,omitempty"`
|
||||
}
|
||||
|
||||
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
||||
return ac.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
if req.Params == nil {
|
||||
return ctx, ErrMissingJWT
|
||||
if !h.cfg.HandleAnnounce {
|
||||
return ctx, nil
|
||||
}
|
||||
var err error
|
||||
|
||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
||||
err = ErrMissingJWT
|
||||
} else {
|
||||
claims := new(announceClaims)
|
||||
if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 {
|
||||
logger.Info().
|
||||
Errs("errors", errs).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("JWT validation failed")
|
||||
err = ErrInvalidJWT
|
||||
} else {
|
||||
var claimIH bittorrent.InfoHash
|
||||
if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil {
|
||||
logger.Info().
|
||||
Err(err).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("InfoHash claim parse failed")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
if req.InfoHash != claimIH {
|
||||
logger.Info().
|
||||
Stringer("provided", claimIH).
|
||||
Stringer("required", req.InfoHash).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("InfoHash claim not equals to request InfoHash")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jwtParam, ok := req.Params.String("jwt")
|
||||
if !ok {
|
||||
return ctx, ErrMissingJWT
|
||||
}
|
||||
|
||||
if errs := h.validateJWT(req.InfoHash, jwtParam); len(errs) > 0 {
|
||||
logger.Info().
|
||||
Errs("errors", errs).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("JWT validation failed")
|
||||
return ctx, ErrInvalidJWT
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||
// Scrapes don't require any protection.
|
||||
return ctx, nil
|
||||
type scrapeClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
InfoHashes []string `json:"infohashes,omitempty"`
|
||||
}
|
||||
|
||||
func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error {
|
||||
// KeyFunc will check KID, Parse will check ALG and signature
|
||||
errs := make([]error, 0, 4)
|
||||
claims := new(claims)
|
||||
token, err := jwt.ParseWithClaims(rawJwt, claims, h.jwks.Keyfunc, hmacAlgorithms)
|
||||
if err != nil {
|
||||
return []error{err}
|
||||
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
||||
return sc.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||
if !h.cfg.HandleScrape {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
if err = token.Claims.Valid(); err != nil {
|
||||
var err error
|
||||
|
||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
||||
err = ErrMissingJWT
|
||||
} else {
|
||||
claims := new(scrapeClaims)
|
||||
if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 {
|
||||
logger.Info().
|
||||
Errs("errors", errs).
|
||||
Array("source", req.RequestAddresses).
|
||||
Msg("JWT validation failed")
|
||||
err = ErrInvalidJWT
|
||||
} else {
|
||||
var claimIHs bittorrent.InfoHashes
|
||||
for _, s := range claims.InfoHashes {
|
||||
if providedIh, err := bittorrent.NewInfoHash(s); err == nil {
|
||||
claimIHs = append(claimIHs, providedIh)
|
||||
} else {
|
||||
logger.Info().
|
||||
Err(err).
|
||||
Array("source", req.RequestAddresses).
|
||||
Msg("InfoHash claim parse failed")
|
||||
}
|
||||
}
|
||||
eq := len(req.InfoHashes) == len(claimIHs)
|
||||
if eq {
|
||||
for _, rIH := range req.InfoHashes {
|
||||
found := false
|
||||
for _, cIH := range claimIHs {
|
||||
if rIH == cIH {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
eq = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !eq {
|
||||
logger.Info().
|
||||
Array("provided", claimIHs).
|
||||
Array("required", req.InfoHashes).
|
||||
Array("source", req.RequestAddresses).
|
||||
Msg("InfoHashes claim not equals to request InfoHashes")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
|
||||
if params != nil {
|
||||
var found bool
|
||||
if jwt, found = params.String(h.cfg.Header); found {
|
||||
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
|
||||
jwt = jwt[len(bearerAuthPrefix):]
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) {
|
||||
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
|
||||
jwtParam = jwtParam[len(bearerAuthPrefix):]
|
||||
}
|
||||
if _, err := jwt.ParseWithClaims(jwtParam, claims, h.jwks.Keyfunc, hmacAlgorithms); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if err := claims.Valid(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
|
||||
errs = append(errs, errs0...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !claims.VerifyIssuer(h.cfg.Issuer, true) {
|
||||
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) {
|
||||
rc := cl.ToRegisteredClaims()
|
||||
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
|
||||
logger.Debug().
|
||||
Str("provided", claims.Issuer).
|
||||
Str("provided", rc.Issuer).
|
||||
Str("required", h.cfg.Issuer).
|
||||
Msg("unequal or missing issuer when validating JWT")
|
||||
errs = append(errs, jwt.ErrTokenInvalidIssuer)
|
||||
}
|
||||
|
||||
if !claims.VerifyAudience(h.cfg.Audience, true) {
|
||||
if !rc.VerifyAudience(h.cfg.Audience, true) {
|
||||
logger.Debug().
|
||||
Strs("provided", claims.Audience).
|
||||
Strs("provided", rc.Audience).
|
||||
Str("required", h.cfg.Audience).
|
||||
Msg("unequal or missing audience when validating JWT")
|
||||
errs = append(errs, jwt.ErrTokenInvalidAudience)
|
||||
}
|
||||
|
||||
providedIh, err := bittorrent.NewInfoHash(claims.InfoHash)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if providedIh != ih {
|
||||
logger.Debug().
|
||||
Err(err).
|
||||
Stringer("provided", providedIh).
|
||||
Stringer("required", ih).
|
||||
Msg("invalid or unequal info hash when validating JWT")
|
||||
errs = append(errs, errInvalidInfoHashClaim)
|
||||
}
|
||||
|
||||
return errs
|
||||
return
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims{
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
@@ -118,6 +118,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
//goland:noinspection HttpUrlsUsage
|
||||
cfg := conf.MapConfig{
|
||||
"handle_announce": true,
|
||||
"issuer": "CN=test",
|
||||
"audience": "test",
|
||||
"jwk_set_url": "http://" + s.Listener.Addr().String(),
|
||||
@@ -126,7 +127,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
||||
h, err := build(cfg, nil)
|
||||
require.Nil(t, err)
|
||||
data := make(map[string]string)
|
||||
data["jwt"] = tokenString
|
||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
RequestPeer: bittorrent.RequestPeer{},
|
||||
@@ -143,7 +144,7 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
||||
|
||||
// now we wll use HMAC-SHA256 with invalid random key
|
||||
// all errors should be nil except announce request
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims{
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
@@ -162,6 +163,8 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
//goland:noinspection HttpUrlsUsage
|
||||
cfg := conf.MapConfig{
|
||||
"handle_announce": true,
|
||||
"header": "jwt",
|
||||
"issuer": "CN=test",
|
||||
"audience": "test",
|
||||
"jwk_set_url": "http://" + s.Listener.Addr().String(),
|
||||
@@ -178,3 +181,44 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
||||
}, nil)
|
||||
require.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestHook_HandleScrapeValid(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(jwksData)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
},
|
||||
InfoHashes: []string{infoHash.String()},
|
||||
})
|
||||
|
||||
token.Header["kid"] = jwksData.Keys[0].KeyID
|
||||
tokenString, err := token.SignedString(privKey)
|
||||
require.Nil(t, err)
|
||||
//goland:noinspection HttpUrlsUsage
|
||||
cfg := conf.MapConfig{
|
||||
"handle_scrape": true,
|
||||
"issuer": "CN=test",
|
||||
"audience": "test",
|
||||
"jwk_set_url": "http://" + s.Listener.Addr().String(),
|
||||
"jwk_set_update_interval": time.Minute,
|
||||
}
|
||||
h, err := build(cfg, nil)
|
||||
require.Nil(t, err)
|
||||
data := make(map[string]string)
|
||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
|
||||
InfoHashes: bittorrent.InfoHashes{infoHash},
|
||||
RequestAddresses: bittorrent.RequestAddresses{},
|
||||
Params: ¶ms{data: data},
|
||||
}, nil)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
)
|
||||
|
||||
// DeriveEntropyFromRequest generates 2*64 bits of pseudo random state from an
|
||||
// AnnounceRequest.
|
||||
//
|
||||
// Calling DeriveEntropyFromRequest multiple times yields the same values.
|
||||
func DeriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) {
|
||||
if len(req.InfoHash) >= bittorrent.InfoHashV1Len {
|
||||
v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16]))
|
||||
}
|
||||
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
|
||||
return
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
// Package random implements the XORShift PRNG and a way to derive random state
|
||||
// from an AnnounceRequest.
|
||||
package random
|
||||
|
||||
// GenerateAndAdvance applies XORShift128Plus on s0 and s1, returning
|
||||
// the new states newS0, newS1 and a pseudo-random number v.
|
||||
func GenerateAndAdvance(s0, s1 uint64) (v, newS0, newS1 uint64) {
|
||||
v = s0 + s1
|
||||
newS0 = s1
|
||||
s0 ^= s0 << 23
|
||||
newS1 = s0 ^ s1 ^ (s0 >> 18) ^ (s1 >> 5)
|
||||
return
|
||||
}
|
||||
|
||||
// Intn generates an int k that satisfies k >= 0 && k < n.
|
||||
// n must be > 0.
|
||||
// It returns the generated k and the new state of the generator.
|
||||
func Intn(s0, s1 uint64, n int) (int, uint64, uint64) {
|
||||
if n <= 0 {
|
||||
panic("invalid n <= 0")
|
||||
}
|
||||
v, newS0, newS1 := GenerateAndAdvance(s0, s1)
|
||||
k := int(v)
|
||||
if k < 0 {
|
||||
k = -k
|
||||
}
|
||||
return k % n, newS0, newS1
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/sot-tech/mochi/pkg/randseed"
|
||||
)
|
||||
|
||||
func TestIntn(t *testing.T) {
|
||||
s0, s1 := rand.Uint64(), rand.Uint64()
|
||||
var k int
|
||||
for i := 0; i < 10000; i++ {
|
||||
k, s0, s1 = Intn(s0, s1, 10)
|
||||
require.True(t, k >= 0, "Intn() must be >= 0")
|
||||
require.True(t, k < 10, "Intn(k) must be < k")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAdvanceXORShift128Plus(b *testing.B) {
|
||||
s0, s1 := rand.Uint64(), rand.Uint64()
|
||||
var v uint64
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s0, s1 = GenerateAndAdvance(s0, s1)
|
||||
}
|
||||
_, _, _ = v, s0, s1
|
||||
}
|
||||
|
||||
func BenchmarkIntn(b *testing.B) {
|
||||
s0, s1 := rand.Uint64(), rand.Uint64()
|
||||
var v int
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s0, s1 = Intn(s0, s1, 1000)
|
||||
}
|
||||
_, _, _ = v, s0, s1
|
||||
}
|
||||
@@ -3,14 +3,15 @@ package varinterval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/middleware"
|
||||
"github.com/sot-tech/mochi/middleware/pkg/random"
|
||||
"github.com/sot-tech/mochi/pkg/conf"
|
||||
"github.com/sot-tech/mochi/storage"
|
||||
)
|
||||
@@ -79,14 +80,12 @@ type hook struct {
|
||||
}
|
||||
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
s0, s1 := random.DeriveEntropyFromRequest(req)
|
||||
// Generate a probability p < 1.0.
|
||||
v, s0, s1 := random.Intn(s0, s1, 1<<24)
|
||||
p := float32(v) / (1 << 24)
|
||||
if h.cfg.ModifyResponseProbability == 1 || p < h.cfg.ModifyResponseProbability {
|
||||
p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req))
|
||||
if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability {
|
||||
// Generate the increase delta.
|
||||
v, _, _ = random.Intn(s0, s1, h.cfg.MaxIncreaseDelta)
|
||||
add := time.Duration(v+1) * time.Second
|
||||
v, _, _ := xoroshiro128p(s0, s1)
|
||||
add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second
|
||||
|
||||
resp.Interval += add
|
||||
|
||||
@@ -102,3 +101,26 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _
|
||||
// Scrapes are not altered.
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// deriveEntropyFromRequest generates 2*64 bits of pseudo random state from an
|
||||
// bittorrent.AnnounceRequest.
|
||||
//
|
||||
// Calling deriveEntropyFromRequest multiple times yields the same values.
|
||||
func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) {
|
||||
if len(req.InfoHash) >= bittorrent.InfoHashV1Len {
|
||||
v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16]))
|
||||
}
|
||||
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
|
||||
return
|
||||
}
|
||||
|
||||
// xoroshiro128p calculates predictable pseudorandom number
|
||||
// with XOR/rotate/shift/rotate 128+ algorithm.
|
||||
// see https://prng.di.unimi.it/xoroshiro128plus.c
|
||||
func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) {
|
||||
result = s0 + s1
|
||||
s1 ^= s0
|
||||
ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
|
||||
ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,12 +3,15 @@ package varinterval
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/conf"
|
||||
|
||||
_ "github.com/sot-tech/mochi/pkg/randseed"
|
||||
)
|
||||
|
||||
var configTests = []struct {
|
||||
@@ -61,3 +64,13 @@ func TestHandleAnnounce(t *testing.T) {
|
||||
require.True(t, resp.Interval > 0, "interval should have been increased")
|
||||
require.True(t, resp.MinInterval > 0, "min_interval should have been increased")
|
||||
}
|
||||
|
||||
func BenchmarkXORoShiRo128Plus(b *testing.B) {
|
||||
s0, s1 := rand.Uint64(), rand.Uint64()
|
||||
var v uint64
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v, s0, s1 = xoroshiro128p(s0, s1)
|
||||
}
|
||||
_, _, _ = v, s0, s1
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user