diff --git a/bittorrent/params.go b/bittorrent/params.go index f7aaf02..3a08b64 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -180,15 +180,15 @@ func parseQuery(query string) (q *QueryParams, err error) { // String returns a string parsed from a query. Every key can be returned as a // string because they are encoded in the URL as strings. -func (qp *QueryParams) String(key string) (string, bool) { - value, ok := qp.params[key] +func (qp QueryParams) String(key string) (string, bool) { + value, ok := qp.params[strings.ToLower(key)] return value, ok } // Uint returns an uint parsed from a query. After being called, it is safe to // cast the uint64 to your desired length. -func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) { - str, exists := qp.params[key] +func (qp QueryParams) Uint(key string, bitSize int) (uint64, error) { + str, exists := qp.params[strings.ToLower(key)] if !exists { return 0, ErrKeyNotFound } @@ -197,17 +197,17 @@ func (qp *QueryParams) Uint(key string, bitSize int) (uint64, error) { } // InfoHashes returns a list of requested infohashes. -func (qp *QueryParams) InfoHashes() []InfoHash { +func (qp QueryParams) InfoHashes() []InfoHash { return qp.infoHashes } // RawPath returns the raw path from the parsed URL. -func (qp *QueryParams) RawPath() string { +func (qp QueryParams) RawPath() string { return qp.path } // RawQuery returns the raw query from the parsed URL. -func (qp *QueryParams) RawQuery() string { +func (qp QueryParams) RawQuery() string { return qp.query } diff --git a/dist/example_config.yaml b/dist/example_config.yaml index b8bfad2..250bfc5 100644 --- a/dist/example_config.yaml +++ b/dist/example_config.yaml @@ -264,10 +264,13 @@ mochi: prehooks: # - name: jwt # options: + # header: "authorization" # issuer: "https://issuer.com" # audience: "https://some.issuer.com" # jwk_set_url: "https://issuer.com/keys" # jwk_set_update_interval: 5m + # handle_announce: true + # handle_scrape: false # # - name: client approval # options: diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 8a3ace4..8e700a2 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -1,15 +1,15 @@ -// Package jwt implements a Hook that fails an Announce if the client's request +// Package jwt implements a Hook that fails on Announce or Scrape if the client's request // is missing a valid JSON Web Token. // // JWTs are validated against the standard claims in RFC7519 along with an -// extra "infohash" claim that verifies the client has access to the Swarm. -// RS256 keys are asynchronously rotated from a provided JWK Set HTTP endpoint. +// extra "infohash(es)" claim that verifies the client has access to the Swarm. package jwt import ( "context" "errors" "fmt" + "strings" "time" "github.com/MicahParks/keyfunc" @@ -24,7 +24,11 @@ import ( ) // Name is the name by which this middleware is registered with Conf. -const Name = "jwt" +const ( + Name = "jwt" + authorizationHeader = "authorization" + bearerAuthPrefix = "bearer " +) func init() { middleware.RegisterBuilder(Name, build) @@ -38,8 +42,6 @@ var ( // ErrInvalidJWT is returned when a JWT fails to verify. ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt") - errInvalidInfoHashClaim = errors.New("token has invalid \"infohash\" claim") - errJWKsNotSet = errors.New("required parameters not provided: Issuer, Audience and/or JWKSetURL") hmacAlgorithms = jwt.WithValidMethods([]string{ @@ -54,10 +56,13 @@ var ( // Config represents all the values required by this middleware to fetch JWKs // and verify JWTs. type Config struct { + Header string Issuer string Audience string JWKSetURL string `cfg:"jwk_set_url"` JWKUpdateInterval time.Duration `cfg:"jwk_set_update_interval"` + HandleAnnounce bool `cfg:"handle_announce"` + HandleScrape bool `cfg:"handle_scrape"` } type hook struct { @@ -65,30 +70,37 @@ type hook struct { jwks *keyfunc.JWKS } -type claims struct { - jwt.RegisteredClaims - InfoHash string `json:"infohash,omitempty"` -} - func build(options conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err error) { var cfg Config + logger.Debug().Object("options", options).Msg("creating new JWT middleware") + if err = options.Unmarshal(&cfg); err != nil { return nil, fmt.Errorf("middleware %s: %w", Name, err) } - logger.Debug().Object("options", options).Msg("creating new JWT middleware") + if len(cfg.JWKSetURL) > 0 && len(cfg.Issuer) > 0 && len(cfg.Audience) > 0 { + if len(cfg.Header) == 0 { + cfg.Header = authorizationHeader + logger.Warn(). + Str("name", "Header"). + Str("default", cfg.Header). + Msg("falling back to default configuration") + } - if len(cfg.JWKSetURL) > 0 { var jwks *keyfunc.JWKS - jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{ - Ctx: context.Background(), - RefreshErrorHandler: func(err error) { - logger.Error().Err(err).Msg("error occurred while updating JWKs") - }, - RefreshInterval: cfg.JWKUpdateInterval, - RefreshUnknownKID: true, - }) + if cfg.HandleAnnounce || cfg.HandleScrape { + jwks, err = keyfunc.Get(cfg.JWKSetURL, keyfunc.Options{ + Ctx: context.Background(), + RefreshErrorHandler: func(err error) { + logger.Error().Err(err).Msg("error occurred while updating JWKs") + }, + RefreshInterval: cfg.JWKUpdateInterval, + RefreshUnknownKID: true, + }) + } else { + logger.Warn().Msg("both announce and scrape handle disabled") + } if err == nil { h = &hook{ cfg: cfg, @@ -111,73 +123,171 @@ func (h *hook) Stop() stop.Result { return c.Result() } +type compatibleClaims interface { + Valid() error + ToRegisteredClaims() jwt.RegisteredClaims +} + +type announceClaims struct { + jwt.RegisteredClaims + InfoHash string `json:"infohash,omitempty"` +} + +func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims { + return ac.RegisteredClaims +} + func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { - if req.Params == nil { - return ctx, ErrMissingJWT + if !h.cfg.HandleAnnounce { + return ctx, nil + } + var err error + + if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { + err = ErrMissingJWT + } else { + claims := new(announceClaims) + if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 { + logger.Info(). + Errs("errors", errs). + Object("source", req.RequestPeer). + Msg("JWT validation failed") + err = ErrInvalidJWT + } else { + var claimIH bittorrent.InfoHash + if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil { + logger.Info(). + Err(err). + Object("source", req.RequestPeer). + Msg("InfoHash claim parse failed") + err = ErrInvalidJWT + } + if req.InfoHash != claimIH { + logger.Info(). + Stringer("provided", claimIH). + Stringer("required", req.InfoHash). + Object("source", req.RequestPeer). + Msg("InfoHash claim not equals to request InfoHash") + err = ErrInvalidJWT + } + } } - jwtParam, ok := req.Params.String("jwt") - if !ok { - return ctx, ErrMissingJWT - } - - if errs := h.validateJWT(req.InfoHash, jwtParam); len(errs) > 0 { - logger.Info(). - Errs("errors", errs). - Object("source", req.RequestPeer). - Msg("JWT validation failed") - return ctx, ErrInvalidJWT - } - - return ctx, nil + return ctx, err } -func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { - // Scrapes don't require any protection. - return ctx, nil +type scrapeClaims struct { + jwt.RegisteredClaims + InfoHashes []string `json:"infohashes,omitempty"` } -func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error { - // KeyFunc will check KID, Parse will check ALG and signature - errs := make([]error, 0, 4) - claims := new(claims) - token, err := jwt.ParseWithClaims(rawJwt, claims, h.jwks.Keyfunc, hmacAlgorithms) - if err != nil { - return []error{err} +func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims { + return sc.RegisteredClaims +} + +func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { + if !h.cfg.HandleScrape { + return ctx, nil } - if err = token.Claims.Valid(); err != nil { + var err error + + if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { + err = ErrMissingJWT + } else { + claims := new(scrapeClaims) + if errs := h.validateBaseJWT(jwtParam, claims); len(errs) > 0 { + logger.Info(). + Errs("errors", errs). + Array("source", req.RequestAddresses). + Msg("JWT validation failed") + err = ErrInvalidJWT + } else { + var claimIHs bittorrent.InfoHashes + for _, s := range claims.InfoHashes { + if providedIh, err := bittorrent.NewInfoHash(s); err == nil { + claimIHs = append(claimIHs, providedIh) + } else { + logger.Info(). + Err(err). + Array("source", req.RequestAddresses). + Msg("InfoHash claim parse failed") + } + } + eq := len(req.InfoHashes) == len(claimIHs) + if eq { + for _, rIH := range req.InfoHashes { + found := false + for _, cIH := range claimIHs { + if rIH == cIH { + found = true + break + } + } + if !found { + eq = false + break + } + } + } + if !eq { + logger.Info(). + Array("provided", claimIHs). + Array("required", req.InfoHashes). + Array("source", req.RequestAddresses). + Msg("InfoHashes claim not equals to request InfoHashes") + err = ErrInvalidJWT + } + } + } + + return ctx, err +} + +func (h *hook) getJWT(params bittorrent.Params) (jwt string) { + if params != nil { + var found bool + if jwt, found = params.String(h.cfg.Header); found { + if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) { + jwt = jwt[len(bearerAuthPrefix):] + } + } + } + return +} + +func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) { + if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) { + jwtParam = jwtParam[len(bearerAuthPrefix):] + } + if _, err := jwt.ParseWithClaims(jwtParam, claims, h.jwks.Keyfunc, hmacAlgorithms); err != nil { errs = append(errs, err) } + if err := claims.Valid(); err != nil { + errs = append(errs, err) + } + if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 { + errs = append(errs, errs0...) + } + return +} - if !claims.VerifyIssuer(h.cfg.Issuer, true) { +func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) { + rc := cl.ToRegisteredClaims() + if !rc.VerifyIssuer(h.cfg.Issuer, true) { logger.Debug(). - Str("provided", claims.Issuer). + Str("provided", rc.Issuer). Str("required", h.cfg.Issuer). Msg("unequal or missing issuer when validating JWT") errs = append(errs, jwt.ErrTokenInvalidIssuer) } - if !claims.VerifyAudience(h.cfg.Audience, true) { + if !rc.VerifyAudience(h.cfg.Audience, true) { logger.Debug(). - Strs("provided", claims.Audience). + Strs("provided", rc.Audience). Str("required", h.cfg.Audience). Msg("unequal or missing audience when validating JWT") errs = append(errs, jwt.ErrTokenInvalidAudience) } - - providedIh, err := bittorrent.NewInfoHash(claims.InfoHash) - if err != nil { - errs = append(errs, err) - } - if providedIh != ih { - logger.Debug(). - Err(err). - Stringer("provided", providedIh). - Stringer("required", ih). - Msg("invalid or unequal info hash when validating JWT") - errs = append(errs, errInvalidInfoHashClaim) - } - - return errs + return } diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index 4e42130..046b83b 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -101,7 +101,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) { })) defer s.Close() - token := jwt.NewWithClaims(jwt.SigningMethodES256, claims{ + token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: "CN=test", Subject: "CN=test", @@ -118,6 +118,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) { require.Nil(t, err) //goland:noinspection HttpUrlsUsage cfg := conf.MapConfig{ + "handle_announce": true, "issuer": "CN=test", "audience": "test", "jwk_set_url": "http://" + s.Listener.Addr().String(), @@ -126,7 +127,7 @@ func TestHook_HandleAnnounceValid(t *testing.T) { h, err := build(cfg, nil) require.Nil(t, err) data := make(map[string]string) - data["jwt"] = tokenString + data[authorizationHeader] = bearerAuthPrefix + tokenString _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ InfoHash: infoHash, RequestPeer: bittorrent.RequestPeer{}, @@ -143,7 +144,7 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) { // now we wll use HMAC-SHA256 with invalid random key // all errors should be nil except announce request - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims{ + token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: "CN=test", Subject: "CN=test", @@ -162,6 +163,8 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) { require.Nil(t, err) //goland:noinspection HttpUrlsUsage cfg := conf.MapConfig{ + "handle_announce": true, + "header": "jwt", "issuer": "CN=test", "audience": "test", "jwk_set_url": "http://" + s.Listener.Addr().String(), @@ -178,3 +181,44 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) { }, nil) require.NotNil(t, err) } + +func TestHook_HandleScrapeValid(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(jwksData) + })) + defer s.Close() + + token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, + InfoHashes: []string{infoHash.String()}, + }) + + token.Header["kid"] = jwksData.Keys[0].KeyID + tokenString, err := token.SignedString(privKey) + require.Nil(t, err) + //goland:noinspection HttpUrlsUsage + cfg := conf.MapConfig{ + "handle_scrape": true, + "issuer": "CN=test", + "audience": "test", + "jwk_set_url": "http://" + s.Listener.Addr().String(), + "jwk_set_update_interval": time.Minute, + } + h, err := build(cfg, nil) + require.Nil(t, err) + data := make(map[string]string) + data[authorizationHeader] = bearerAuthPrefix + tokenString + _, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{ + InfoHashes: bittorrent.InfoHashes{infoHash}, + RequestAddresses: bittorrent.RequestAddresses{}, + Params: ¶ms{data: data}, + }, nil) + require.Nil(t, err) +} diff --git a/middleware/pkg/random/entropy.go b/middleware/pkg/random/entropy.go deleted file mode 100644 index 6fd01bf..0000000 --- a/middleware/pkg/random/entropy.go +++ /dev/null @@ -1,19 +0,0 @@ -package random - -import ( - "encoding/binary" - - "github.com/sot-tech/mochi/bittorrent" -) - -// DeriveEntropyFromRequest generates 2*64 bits of pseudo random state from an -// AnnounceRequest. -// -// Calling DeriveEntropyFromRequest multiple times yields the same values. -func DeriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) { - if len(req.InfoHash) >= bittorrent.InfoHashV1Len { - v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16])) - } - v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16]) - return -} diff --git a/middleware/pkg/random/xorshift.go b/middleware/pkg/random/xorshift.go deleted file mode 100644 index 78d5f03..0000000 --- a/middleware/pkg/random/xorshift.go +++ /dev/null @@ -1,28 +0,0 @@ -// Package random implements the XORShift PRNG and a way to derive random state -// from an AnnounceRequest. -package random - -// GenerateAndAdvance applies XORShift128Plus on s0 and s1, returning -// the new states newS0, newS1 and a pseudo-random number v. -func GenerateAndAdvance(s0, s1 uint64) (v, newS0, newS1 uint64) { - v = s0 + s1 - newS0 = s1 - s0 ^= s0 << 23 - newS1 = s0 ^ s1 ^ (s0 >> 18) ^ (s1 >> 5) - return -} - -// Intn generates an int k that satisfies k >= 0 && k < n. -// n must be > 0. -// It returns the generated k and the new state of the generator. -func Intn(s0, s1 uint64, n int) (int, uint64, uint64) { - if n <= 0 { - panic("invalid n <= 0") - } - v, newS0, newS1 := GenerateAndAdvance(s0, s1) - k := int(v) - if k < 0 { - k = -k - } - return k % n, newS0, newS1 -} diff --git a/middleware/pkg/random/xorshift_test.go b/middleware/pkg/random/xorshift_test.go deleted file mode 100644 index 6df98b1..0000000 --- a/middleware/pkg/random/xorshift_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package random - -import ( - "math/rand" - "testing" - - "github.com/stretchr/testify/require" - - _ "github.com/sot-tech/mochi/pkg/randseed" -) - -func TestIntn(t *testing.T) { - s0, s1 := rand.Uint64(), rand.Uint64() - var k int - for i := 0; i < 10000; i++ { - k, s0, s1 = Intn(s0, s1, 10) - require.True(t, k >= 0, "Intn() must be >= 0") - require.True(t, k < 10, "Intn(k) must be < k") - } -} - -func BenchmarkAdvanceXORShift128Plus(b *testing.B) { - s0, s1 := rand.Uint64(), rand.Uint64() - var v uint64 - for i := 0; i < b.N; i++ { - v, s0, s1 = GenerateAndAdvance(s0, s1) - } - _, _, _ = v, s0, s1 -} - -func BenchmarkIntn(b *testing.B) { - s0, s1 := rand.Uint64(), rand.Uint64() - var v int - for i := 0; i < b.N; i++ { - v, s0, s1 = Intn(s0, s1, 1000) - } - _, _, _ = v, s0, s1 -} diff --git a/middleware/varinterval/varinterval.go b/middleware/varinterval/varinterval.go index f9540ac..c677c76 100644 --- a/middleware/varinterval/varinterval.go +++ b/middleware/varinterval/varinterval.go @@ -3,14 +3,15 @@ package varinterval import ( "context" + "encoding/binary" "errors" "fmt" + "math" "sync" "time" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/middleware" - "github.com/sot-tech/mochi/middleware/pkg/random" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/storage" ) @@ -79,14 +80,12 @@ type hook struct { } func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) { - s0, s1 := random.DeriveEntropyFromRequest(req) // Generate a probability p < 1.0. - v, s0, s1 := random.Intn(s0, s1, 1<<24) - p := float32(v) / (1 << 24) - if h.cfg.ModifyResponseProbability == 1 || p < h.cfg.ModifyResponseProbability { + p, s0, s1 := xoroshiro128p(deriveEntropyFromRequest(req)) + if float32(float64(p)/math.MaxUint64) < h.cfg.ModifyResponseProbability { // Generate the increase delta. - v, _, _ = random.Intn(s0, s1, h.cfg.MaxIncreaseDelta) - add := time.Duration(v+1) * time.Second + v, _, _ := xoroshiro128p(s0, s1) + add := time.Duration(v%uint64(h.cfg.MaxIncreaseDelta)+1) * time.Second resp.Interval += add @@ -102,3 +101,26 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ // Scrapes are not altered. return ctx, nil } + +// deriveEntropyFromRequest generates 2*64 bits of pseudo random state from an +// bittorrent.AnnounceRequest. +// +// Calling deriveEntropyFromRequest multiple times yields the same values. +func deriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 uint64) { + if len(req.InfoHash) >= bittorrent.InfoHashV1Len { + v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + binary.BigEndian.Uint64([]byte(req.InfoHash[8:16])) + } + v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16]) + return +} + +// xoroshiro128p calculates predictable pseudorandom number +// with XOR/rotate/shift/rotate 128+ algorithm. +// see https://prng.di.unimi.it/xoroshiro128plus.c +func xoroshiro128p(s0, s1 uint64) (result, ns0, ns1 uint64) { + result = s0 + s1 + s1 ^= s0 + ns0 = ((s0 << 24) | (s0 >> 40)) ^ s1 ^ (s1 << 16) // rotl(s0, 24) ^ s1 ^ (s1 << 16) + ns1 = (s1 << 37) | (s1 >> 27) // rotl(s1, 37) + return +} diff --git a/middleware/varinterval/varinterval_test.go b/middleware/varinterval/varinterval_test.go index f8c298e..8ff8ee6 100644 --- a/middleware/varinterval/varinterval_test.go +++ b/middleware/varinterval/varinterval_test.go @@ -3,12 +3,15 @@ package varinterval import ( "context" "fmt" + "math/rand" "testing" "github.com/stretchr/testify/require" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" + + _ "github.com/sot-tech/mochi/pkg/randseed" ) var configTests = []struct { @@ -61,3 +64,13 @@ func TestHandleAnnounce(t *testing.T) { require.True(t, resp.Interval > 0, "interval should have been increased") require.True(t, resp.MinInterval > 0, "min_interval should have been increased") } + +func BenchmarkXORoShiRo128Plus(b *testing.B) { + s0, s1 := rand.Uint64(), rand.Uint64() + var v uint64 + b.ResetTimer() + for i := 0; i < b.N; i++ { + v, s0, s1 = xoroshiro128p(s0, s1) + } + _, _, _ = v, s0, s1 +}