(partially tested) add scrape JWT validation

* remove `xorshift` package, add internal function in `varinterval`
* change `bittorrent.QueryParams` getters to search case insensitive keys
This commit is contained in:
Lawrence, Rendall
2022-08-25 18:31:48 +03:00
parent f5a58630db
commit 498779aeaf
9 changed files with 276 additions and 169 deletions
+7 -7
View File
@@ -180,15 +180,15 @@ func parseQuery(query string) (q *QueryParams, err error) {
// String returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings.
func (qp *QueryParams) String(key string) (string, bool) {
value, ok := qp.params[key]
func (qp QueryParams) String(key string) (string, bool) {
value, ok := qp.params[strings.ToLower(key)]
return value, ok
}
// Uint returns an uint parsed from a query. After being called, it is safe to
// cast the uint64 to your desired length.
func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) {
str, exists := qp.params[key]
func (qp QueryParams) Uint(key string, bitSize int) (uint64, error) {
str, exists := qp.params[strings.ToLower(key)]
if !exists {
return 0, ErrKeyNotFound
}
@@ -197,17 +197,17 @@ func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) {
}
// InfoHashes returns a list of requested infohashes.
func (qp *QueryParams) InfoHashes() []InfoHash {
func (qp QueryParams) InfoHashes() []InfoHash {
return qp.infoHashes
}
// RawPath returns the raw path from the parsed URL.
func (qp *QueryParams) RawPath() string {
func (qp QueryParams) RawPath() string {
return qp.path
}
// RawQuery returns the raw query from the parsed URL.
func (qp *QueryParams) RawQuery() string {
func (qp QueryParams) RawQuery() string {
return qp.query
}
+3
View File
@@ -264,10 +264,13 @@ mochi:
prehooks:
# - name: jwt
# options:
# header: "authorization"
# issuer: "https://issuer.com"
# audience: "https://some.issuer.com"
# jwk_set_url: "https://issuer.com/keys"
# jwk_set_update_interval: 5m
# handle_announce: true
# handle_scrape: false
#
# - name: client approval
# options:
+177 -67
View File
@@ -1,15 +1,15 @@
// Package jwt implements a Hook that fails an Announce if the client's request
// Package jwt implements a Hook that fails on Announce or Scrape if the client's request
// is missing a valid JSON Web Token.
//
// JWTs are validated against the standard claims in RFC7519 along with an
// extra "infohash" claim that verifies the client has access to the Swarm.
// RS256 keys are asynchronously rotated from a provided JWK Set HTTP endpoint.
// extra "infohash(es)" claim that verifies the client has access to the Swarm.
package jwt
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/MicahParks/keyfunc"
@@ -24,7 +24,11 @@ import (
)
// Name is the name by which this middleware is registered with Conf.
const Name = "jwt"
const (
Name = "jwt"
authorizationHeader = "authorization"
bearerAuthPrefix = "bearer "
)
func init() {
middleware.RegisterBuilder(Name, build)
@@ -38,8 +42,6 @@ var (
// ErrInvalidJWT is returned when a JWT fails to verify.
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
errInvalidInfoHashClaim = errors.New("token has invalid \"infohash\" claim")
errJWKsNotSet = errors.New("required parameters not provided: Issuer, Audience and/or JWKSetURL")
hmacAlgorithms = jwt.WithValidMethods([]string{
@@ -54,10 +56,13 @@ var (
// Config represents all the values required by this middleware to fetch JWKs
// and verify JWTs.
type Config struct {
Header string
Issuer string
Audience string
JWKSetURL string `cfg:"jwk_set_url"`
JWKUpdateInterval time.Duration `cfg:"jwk_set_update_interval"`
HandleAnnounce bool `cfg:"handle_announce"`
HandleScrape bool `cfg:"handle_scrape"`
}
type hook struct {
@@ -65,30 +70,37 @@ type hook struct {
jwks *keyfunc.JWKS
}
type claims struct {
jwt.RegisteredClaims
InfoHash string `json:"infohash,omitempty"`
}
func build(options conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err error) {
var cfg Config
logger.Debug().Object("options", options).Msg("creating new JWT middleware")
if err = options.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("middleware %s: %w", Name, err)
}
logger.Debug().Object("options", options).Msg("creating new JWT middleware")
if len(cfg.JWKSetURL) > 0 && len(cfg.Issuer) > 0 && len(cfg.Audience) > 0 {
if len(cfg.Header) == 0 {
cfg.Header = authorizationHeader
logger.Warn().
Str("name", "Header").
Str("default", cfg.Header).
Msg("falling back to default configuration")
}
if len(cfg.JWKSetURL) > 0 {
var jwks *keyfunc.JWKS
jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{
Ctx: context.Background(),
RefreshErrorHandler: func(err error) {
logger.Error().Err(err).Msg("error occurred while updating JWKs")
},
RefreshInterval: cfg.JWKUpdateInterval,
RefreshUnknownKID: true,
})
if cfg.HandleAnnounce || cfg.HandleScrape {
jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{
Ctx: context.Background(),
RefreshErrorHandler: func(err error) {
logger.Error().Err(err).Msg("error occurred while updating JWKs")
},
RefreshInterval: cfg.JWKUpdateInterval,
RefreshUnknownKID: true,
})
} else {
logger.Warn().Msg("both announce and scrape handle disabled")
}
if err == nil {
h = &hook{
cfg: cfg,
@@ -111,73 +123,171 @@ func (h *hook) Stop() stop.Result {
return c.Result()
}
type compatibleClaims interface {
Valid() error
ToRegisteredClaims() jwt.RegisteredClaims
}
type announceClaims struct {
jwt.RegisteredClaims
InfoHash string `json:"infohash,omitempty"`
}
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return ac.RegisteredClaims
}
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
if req.Params == nil {
return ctx, ErrMissingJWT
if !h.cfg.HandleAnnounce {
return ctx, nil
}
var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT
} else {
claims := new(announceClaims)
if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 {
logger.Info().
Errs("errors", errs).
Object("source", req.RequestPeer).
Msg("JWT validation failed")
err = ErrInvalidJWT
} else {
var claimIH bittorrent.InfoHash
if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil {
logger.Info().
Err(err).
Object("source", req.RequestPeer).
Msg("InfoHash claim parse failed")
err = ErrInvalidJWT
}
if req.InfoHash != claimIH {
logger.Info().
Stringer("provided", claimIH).
Stringer("required", req.InfoHash).
Object("source", req.RequestPeer).
Msg("InfoHash claim not equals to request InfoHash")
err = ErrInvalidJWT
}
}
}
jwtParam, ok := req.Params.String("jwt")
if !ok {
return ctx, ErrMissingJWT
}
if errs := h.validateJWT(req.InfoHash, jwtParam); len(errs) > 0 {
logger.Info().
Errs("errors", errs).
Object("source", req.RequestPeer).
Msg("JWT validation failed")
return ctx, ErrInvalidJWT
}
return ctx, nil
return ctx, err
}
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
// Scrapes don't require any protection.
return ctx, nil
type scrapeClaims struct {
jwt.RegisteredClaims
InfoHashes []string `json:"infohashes,omitempty"`
}
func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error {
// KeyFunc will check KID, Parse will check ALG and signature
errs := make([]error, 0, 4)
claims := new(claims)
token, err := jwt.ParseWithClaims(rawJwt, claims, h.jwks.Keyfunc, hmacAlgorithms)
if err != nil {
return []error{err}
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return sc.RegisteredClaims
}
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
if !h.cfg.HandleScrape {
return ctx, nil
}
if err = token.Claims.Valid(); err != nil {
var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT
} else {
claims := new(scrapeClaims)
if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 {
logger.Info().
Errs("errors", errs).
Array("source", req.RequestAddresses).
Msg("JWT validation failed")
err = ErrInvalidJWT
} else {
var claimIHs bittorrent.InfoHashes
for _, s := range claims.InfoHashes {
if providedIh, err := bittorrent.NewInfoHash(s); err == nil {
claimIHs = append(claimIHs, providedIh)
} else {
logger.Info().
Err(err).
Array("source", req.RequestAddresses).
Msg("InfoHash claim parse failed")
}
}
eq := len(req.InfoHashes) == len(claimIHs)
if eq {
for _, rIH := range req.InfoHashes {
found := false
for _, cIH := range claimIHs {
if rIH == cIH {
found = true
break
}
}
if !found {
eq = false
break
}
}
}
if !eq {
logger.Info().
Array("provided", claimIHs).
Array("required", req.InfoHashes).
Array("source", req.RequestAddresses).
Msg("InfoHashes claim not equals to request InfoHashes")
err = ErrInvalidJWT
}
}
}
return ctx, err
}
func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
if params != nil {
var found bool
if jwt, found = params.String(h.cfg.Header); found {
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
jwt = jwt[len(bearerAuthPrefix):]
}
}
}
return
}
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) {
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
jwtParam = jwtParam[len(bearerAuthPrefix):]
}
if _, err := jwt.ParseWithClaims(jwtParam, claims, h.jwks.Keyfunc, hmacAlgorithms); err != nil {
errs = append(errs, err)
}
if err := claims.Valid(); err != nil {
errs = append(errs, err)
}
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
errs = append(errs, errs0...)
}
return
}
if !claims.VerifyIssuer(h.cfg.Issuer, true) {
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) {
rc := cl.ToRegisteredClaims()
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
logger.Debug().
Str("provided", claims.Issuer).
Str("provided", rc.Issuer).
Str("required", h.cfg.Issuer).
Msg("unequal or missing issuer when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidIssuer)
}
if !claims.VerifyAudience(h.cfg.Audience, true) {
if !rc.VerifyAudience(h.cfg.Audience, true) {
logger.Debug().
Strs("provided", claims.Audience).
Strs("provided", rc.Audience).
Str("required", h.cfg.Audience).
Msg("unequal or missing audience when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidAudience)
}
providedIh, err := bittorrent.NewInfoHash(claims.InfoHash)
if err != nil {
errs = append(errs, err)
}
if providedIh != ih {
logger.Debug().
Err(err).
Stringer("provided", providedIh).
Stringer("required", ih).
Msg("invalid or unequal info hash when validating JWT")
errs = append(errs, errInvalidInfoHashClaim)
}
return errs
return
}
+47 -3
View File
@@ -101,7 +101,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
}))
defer s.Close()
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims{
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
@@ -118,6 +118,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
require.Nil(t, err)
//goland:noinspection HttpUrlsUsage
cfg := conf.MapConfig{
"handle_announce": true,
"issuer": "CN=test",
"audience": "test",
"jwk_set_url": "http://" + s.Listener.Addr().String(),
@@ -126,7 +127,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
h, err := build(cfg, nil)
require.Nil(t, err)
data := make(map[string]string)
data["jwt"] = tokenString
data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
InfoHash: infoHash,
RequestPeer: bittorrent.RequestPeer{},
@@ -143,7 +144,7 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
// now we wll use HMAC-SHA256 with invalid random key
// all errors should be nil except announce request
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims{
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
@@ -162,6 +163,8 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
require.Nil(t, err)
//goland:noinspection HttpUrlsUsage
cfg := conf.MapConfig{
"handle_announce": true,
"header": "jwt",
"issuer": "CN=test",
"audience": "test",
"jwk_set_url": "http://" + s.Listener.Addr().String(),
@@ -178,3 +181,44 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
}, nil)
require.NotNil(t, err)
}
func TestHook_HandleScrapeValid(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(jwksData)
}))
defer s.Close()
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
InfoHashes: []string{infoHash.String()},
})
token.Header["kid"] = jwksData.Keys[0].KeyID
tokenString, err := token.SignedString(privKey)
require.Nil(t, err)
//goland:noinspection HttpUrlsUsage
cfg := conf.MapConfig{
"handle_scrape": true,
"issuer": "CN=test",
"audience": "test",
"jwk_set_url": "http://" + s.Listener.Addr().String(),
"jwk_set_update_interval": time.Minute,
}
h, err := build(cfg, nil)
require.Nil(t, err)
data := make(map[string]string)
data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
InfoHashes: bittorrent.InfoHashes{infoHash},
RequestAddresses: bittorrent.RequestAddresses{},
Params: &params{data: data},
}, nil)
require.Nil(t, err)
}
-19
View File
@@ -1,19 +0,0 @@
package random
import (
"encoding/binary"
"github.com/sot-tech/mochi/bittorrent"
)
// DeriveEntropyFromRequest generates 2*64 bits of pseudo random state from an
// AnnounceRequest.
//
// Calling DeriveEntropyFromRequest multiple times yields the same values.
func DeriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) {
if len(req.InfoHash) >= bittorrent.InfoHashV1Len {
v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16]))
}
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
return
}
-28
View File
@@ -1,28 +0,0 @@
// Package random implements the XORShift PRNG and a way to derive random state
// from an AnnounceRequest.
package random
// GenerateAndAdvance applies XORShift128Plus on s0 and s1, returning
// the new states newS0, newS1 and a pseudo-random number v.
func GenerateAndAdvance(s0, s1 uint64) (v, newS0, newS1 uint64) {
v = s0 + s1
newS0 = s1
s0 ^= s0 << 23
newS1 = s0 ^ s1 ^ (s0 >> 18) ^ (s1 >> 5)
return
}
// Intn generates an int k that satisfies k >= 0 && k < n.
// n must be > 0.
// It returns the generated k and the new state of the generator.
func Intn(s0, s1 uint64, n int) (int, uint64, uint64) {
if n <= 0 {
panic("invalid n <= 0")
}
v, newS0, newS1 := GenerateAndAdvance(s0, s1)
k := int(v)
if k < 0 {
k = -k
}
return k % n, newS0, newS1
}
-38
View File
@@ -1,38 +0,0 @@
package random
import (
"math/rand"
"testing"
"github.com/stretchr/testify/require"
_ "github.com/sot-tech/mochi/pkg/randseed"
)
func TestIntn(t *testing.T) {
s0, s1 := rand.Uint64(), rand.Uint64()
var k int
for i := 0; i < 10000; i++ {
k, s0, s1 = Intn(s0, s1, 10)
require.True(t, k >= 0, "Intn() must be >= 0")
require.True(t, k < 10, "Intn(k) must be < k")
}
}
func BenchmarkAdvanceXORShift128Plus(b *testing.B) {
s0, s1 := rand.Uint64(), rand.Uint64()
var v uint64
for i := 0; i < b.N; i++ {
v, s0, s1 = GenerateAndAdvance(s0, s1)
}
_, _, _ = v, s0, s1
}
func BenchmarkIntn(b *testing.B) {
s0, s1 := rand.Uint64(), rand.Uint64()
var v int
for i := 0; i < b.N; i++ {
v, s0, s1 = Intn(s0, s1, 1000)
}
_, _, _ = v, s0, s1
}
+29 -7
View File
@@ -3,14 +3,15 @@ package varinterval
import (
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"sync"
"time"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/middleware/pkg/random"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/storage"
)
@@ -79,14 +80,12 @@ type hook struct {
}
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
s0, s1 := random.DeriveEntropyFromRequest(req)
// Generate a probability p < 1.0.
v, s0, s1 := random.Intn(s0, s1, 1<<24)
p := float32(v) / (1 << 24)
if h.cfg.ModifyResponseProbability == 1 || p < h.cfg.ModifyResponseProbability {
p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req))
if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability {
// Generate the increase delta.
v, _, _ = random.Intn(s0, s1, h.cfg.MaxIncreaseDelta)
add := time.Duration(v+1) * time.Second
v, _, _ := xoroshiro128p(s0, s1)
add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second
resp.Interval += add
@@ -102,3 +101,26 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _
// Scrapes are not altered.
return ctx, nil
}
// deriveEntropyFromRequest generates 2*64 bits of pseudo random state from an
// bittorrent.AnnounceRequest.
//
// Calling deriveEntropyFromRequest multiple times yields the same values.
func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) {
if len(req.InfoHash) >= bittorrent.InfoHashV1Len {
v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16]))
}
v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16])
return
}
// xoroshiro128p calculates predictable pseudorandom number
// with XOR/rotate/shift/rotate 128+ algorithm.
// see https://prng.di.unimi.it/xoroshiro128plus.c
func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) {
result = s0 + s1
s1 ^= s0
ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16)
ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37)
return
}
@@ -3,12 +3,15 @@ package varinterval
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/stretchr/testify/require"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/conf"
_ "github.com/sot-tech/mochi/pkg/randseed"
)
var configTests = []struct {
@@ -61,3 +64,13 @@ func TestHandleAnnounce(t *testing.T) {
require.True(t, resp.Interval > 0, "interval should have been increased")
require.True(t, resp.MinInterval > 0, "min_interval should have been increased")
}
func BenchmarkXORoShiRo128Plus(b *testing.B) {
s0, s1 := rand.Uint64(), rand.Uint64()
var v uint64
b.ResetTimer()
for i := 0; i < b.N; i++ {
v, s0, s1 = xoroshiro128p(s0, s1)
}
_, _, _ = v, s0, s1
}