diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 4abb5db..8a3ace4 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -8,7 +8,6 @@ package jwt import ( "context" - "crypto/subtle" "errors" "fmt" "time" @@ -43,7 +42,13 @@ var ( errJWKsNotSet = errors.New("required parameters not provided: Issuer, Audience and/or JWKSetURL") - hmacAlgorithms = jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name, jwt.SigningMethodHS384.Name, jwt.SigningMethodHS512.Name}) + hmacAlgorithms = jwt.WithValidMethods([]string{ + jwt.SigningMethodHS256.Alg(), jwt.SigningMethodHS384.Alg(), jwt.SigningMethodHS512.Alg(), + jwt.SigningMethodRS256.Alg(), jwt.SigningMethodRS384.Alg(), jwt.SigningMethodRS512.Alg(), + jwt.SigningMethodPS256.Alg(), jwt.SigningMethodPS384.Alg(), jwt.SigningMethodPS512.Alg(), + jwt.SigningMethodES256.Alg(), jwt.SigningMethodES384.Alg(), jwt.SigningMethodES512.Alg(), + jwt.SigningMethodEdDSA.Alg(), + }) ) // Config represents all the values required by this middleware to fetch JWKs @@ -135,7 +140,8 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error { // KeyFunc will check KID, Parse will check ALG and signature errs := make([]error, 0, 4) - token, err := jwt.ParseWithClaims(rawJwt, claims{}, h.jwks.Keyfunc, hmacAlgorithms) + claims := new(claims) + token, err := jwt.ParseWithClaims(rawJwt, claims, h.jwks.Keyfunc, hmacAlgorithms) if err != nil { return []error{err} } @@ -144,8 +150,6 @@ func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error { errs = append(errs, err) } - claims := token.Claims.(claims) - if !claims.VerifyIssuer(h.cfg.Issuer, true) { logger.Debug(). Str("provided", claims.Issuer). @@ -166,8 +170,8 @@ func (h *hook) validateJWT(ih bittorrent.InfoHash, rawJwt string) []error { if err != nil { errs = append(errs, err) } - if subtle.ConstantTimeCompare([]byte(providedIh), []byte(ih)) != 0 { - logger.Error(). + if providedIh != ih { + logger.Debug(). Err(err). Stringer("provided", providedIh). Stringer("required", ih). diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go new file mode 100644 index 0000000..4e42130 --- /dev/null +++ b/middleware/jwt/jwt_test.go @@ -0,0 +1,180 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "encoding/json" + "math/rand" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/minio/sha256-simd" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/conf" + "github.com/sot-tech/mochi/pkg/log" + _ "github.com/sot-tech/mochi/pkg/randseed" +) + +const ( + privKeyPEM = ` +-----BEGIN PRIVATE KEY----- +MEECAQAwEwYHKoZIzj0CAQYIKoZIzj0DAQcEJzAlAgEBBCCI7Zc2IUKazCBCK5VY +WxxE6lVhGR+exaWgrh0Yq9t4gQ== +-----END PRIVATE KEY----- +` +) + +var ( + privKey *ecdsa.PrivateKey + infoHash bittorrent.InfoHash + jwksData JWKSKeys +) + +type JWKSKey struct { + KeyType string `json:"kty"` + Usage string `json:"use"` + KeyID string `json:"kid"` + Algorithm string `json:"alg"` + Curve string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` +} + +type JWKSKeys struct { + Keys []JWKSKey `json:"keys"` +} + +type params struct { + data map[string]string +} + +func (p params) String(key string) (out string, found bool) { + out, found = p.data[key] + return +} + +func (params) RawPath() (s string) { + return +} + +func (params) RawQuery() (s string) { + return +} + +func (params) MarshalZerologObject(*zerolog.Event) { + return +} + +func init() { + _ = log.ConfigureLogger("", "info", false, false) + privKey, _ = jwt.ParseECPrivateKeyFromPEM([]byte(privKeyPEM)) + ihBytes := make([]byte, bittorrent.InfoHashV1Len) + rand.Read(ihBytes) + infoHash, _ = bittorrent.NewInfoHash(ihBytes) + s2 := sha256.New() + s2.Write(elliptic.Marshal(privKey.PublicKey.Curve, privKey.PublicKey.X, privKey.PublicKey.Y)) + jwksData = JWKSKeys{Keys: []JWKSKey{ + { + KeyType: "EC", + Usage: "sig", + KeyID: base64.RawURLEncoding.EncodeToString(s2.Sum(nil)), + Algorithm: jwt.SigningMethodES256.Name, + Curve: privKey.Curve.Params().Name, + X: base64.RawURLEncoding.EncodeToString(privKey.PublicKey.X.Bytes()), + Y: base64.RawURLEncoding.EncodeToString(privKey.PublicKey.Y.Bytes()), + }, + }} +} + +func TestHook_HandleAnnounceValid(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(jwksData) + })) + defer s.Close() + + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, + InfoHash: infoHash.String(), + }) + + token.Header["kid"] = jwksData.Keys[0].KeyID + tokenString, err := token.SignedString(privKey) + require.Nil(t, err) + //goland:noinspection HttpUrlsUsage + cfg := conf.MapConfig{ + "issuer": "CN=test", + "audience": "test", + "jwk_set_url": "http://" + s.Listener.Addr().String(), + "jwk_set_update_interval": time.Minute, + } + h, err := build(cfg, nil) + require.Nil(t, err) + data := make(map[string]string) + data["jwt"] = tokenString + _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ + InfoHash: infoHash, + RequestPeer: bittorrent.RequestPeer{}, + Params: ¶ms{data: data}, + }, nil) + require.Nil(t, err) +} + +func TestHook_HandleAnnounceInvalid(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(jwksData) + })) + defer s.Close() + + // now we wll use HMAC-SHA256 with invalid random key + // all errors should be nil except announce request + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, + InfoHash: infoHash.String(), + }) + + token.Header["kid"] = jwksData.Keys[0].KeyID + k := make([]byte, 20) + rand.Read(k) + tokenString, err := token.SignedString(k) + require.Nil(t, err) + //goland:noinspection HttpUrlsUsage + cfg := conf.MapConfig{ + "issuer": "CN=test", + "audience": "test", + "jwk_set_url": "http://" + s.Listener.Addr().String(), + "jwk_set_update_interval": time.Minute, + } + h, err := build(cfg, nil) + require.Nil(t, err) + data := make(map[string]string) + data["jwt"] = tokenString + _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ + InfoHash: infoHash, + RequestPeer: bittorrent.RequestPeer{}, + Params: ¶ms{data: data}, + }, nil) + require.NotNil(t, err) +}