diff --git a/.golangci.yaml b/.golangci.yaml index 31cbeb7..b315bdb 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,5 +1,8 @@ --- run: + # mochi in not written with generics (a.t.m), + # so we can check with 1.17 + go: "1.17" timeout: "5m" output: sort-results: true @@ -51,4 +54,4 @@ issues: - "EXC0012" # Exported should have comment - "EXC0013" # Package comment should be of form - "EXC0014" # Comment on exported should be of form - - "EXC0015" # Should have a package comment \ No newline at end of file + - "EXC0015" # Should have a package comment diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 8e700a2..82aae1b 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -123,27 +123,38 @@ func (h *hook) Stop() stop.Result { return c.Result() } -type compatibleClaims interface { - Valid() error - ToRegisteredClaims() jwt.RegisteredClaims +type verifiableClaims interface { + jwt.Claims + VerifyIssuer(iss string, req bool) bool + GetIssuer() string + VerifyAudience(aud string, req bool) bool + GetAudience() []string +} + +type registeredClaimsWrapper struct { + jwt.RegisteredClaims +} + +func (rc registeredClaimsWrapper) GetIssuer() string { + return rc.Issuer +} + +func (rc registeredClaimsWrapper) GetAudience() []string { + return rc.Audience } type announceClaims struct { - jwt.RegisteredClaims + registeredClaimsWrapper InfoHash string `json:"infohash,omitempty"` } -func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims { - return ac.RegisteredClaims -} - func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { if !h.cfg.HandleAnnounce { return ctx, nil } var err error - if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { + if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 { err = ErrMissingJWT } else { claims := new(announceClaims) @@ -159,15 +170,15 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque logger.Info(). Err(err). Object("source", req.RequestPeer). - Msg("InfoHash claim parse failed") + Msg("'infohash' claim parse failed") err = ErrInvalidJWT } if req.InfoHash != claimIH { logger.Info(). - Stringer("provided", claimIH). - Stringer("required", req.InfoHash). - Object("source", req.RequestPeer). - Msg("InfoHash claim not equals to request InfoHash") + Stringer("claimInfoHash", claimIH). + Stringer("requestInfoHash", req.InfoHash). + Object("peer", req.RequestPeer). + Msg("unequal 'infohash' claim when validating JWT") err = ErrInvalidJWT } } @@ -177,14 +188,10 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque } type scrapeClaims struct { - jwt.RegisteredClaims + registeredClaimsWrapper InfoHashes []string `json:"infohashes,omitempty"` } -func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims { - return sc.RegisteredClaims -} - func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { if !h.cfg.HandleScrape { return ctx, nil @@ -192,7 +199,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, var err error - if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { + if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 { err = ErrMissingJWT } else { claims := new(scrapeClaims) @@ -210,8 +217,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, } else { logger.Info(). Err(err). - Array("source", req.RequestAddresses). - Msg("InfoHash claim parse failed") + Array("addresses", req.RequestAddresses). + Msg("'infohashes' claim parse failed") } } eq := len(req.InfoHashes) == len(claimIHs) @@ -232,10 +239,10 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, } if !eq { logger.Info(). - Array("provided", claimIHs). - Array("required", req.InfoHashes). - Array("source", req.RequestAddresses). - Msg("InfoHashes claim not equals to request InfoHashes") + Array("claimInfoHashes", claimIHs). + Array("requestInfoHashes", req.InfoHashes). + Array("addresses", req.RequestAddresses). + Msg("unequal 'infohashes' claim when validating JWT") err = ErrInvalidJWT } } @@ -244,7 +251,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, return ctx, err } -func (h *hook) getJWT(params bittorrent.Params) (jwt string) { +func (h *hook) getJWTString(params bittorrent.Params) (jwt string) { if params != nil { var found bool if jwt, found = params.String(h.cfg.Header); found { @@ -256,7 +263,7 @@ func (h *hook) getJWT(params bittorrent.Params) (jwt string) { return } -func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) { +func (h *hook) validateBaseJWT(jwtParam string, claims verifiableClaims) (errs []error) { if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) { jwtParam = jwtParam[len(bearerAuthPrefix):] } @@ -266,25 +273,17 @@ func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs [ if err := claims.Valid(); err != nil { errs = append(errs, err) } - if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 { - errs = append(errs, errs0...) - } - return -} -func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) { - rc := cl.ToRegisteredClaims() - if !rc.VerifyIssuer(h.cfg.Issuer, true) { + if !claims.VerifyIssuer(h.cfg.Issuer, true) { logger.Debug(). - Str("provided", rc.Issuer). + Str("provided", claims.GetIssuer()). Str("required", h.cfg.Issuer). Msg("unequal or missing issuer when validating JWT") errs = append(errs, jwt.ErrTokenInvalidIssuer) } - - if !rc.VerifyAudience(h.cfg.Audience, true) { + if !claims.VerifyAudience(h.cfg.Audience, true) { logger.Debug(). - Strs("provided", rc.Audience). + Strs("provided", claims.GetAudience()). Str("required", h.cfg.Audience). Msg("unequal or missing audience when validating JWT") errs = append(errs, jwt.ErrTokenInvalidAudience) diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index 046b83b..04a4983 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -53,12 +53,10 @@ type JWKSKeys struct { Keys []JWKSKey `json:"keys"` } -type params struct { - data map[string]string -} +type params map[string]string func (p params) String(key string) (out string, found bool) { - out, found = p.data[key] + out, found = p[key] return } @@ -70,9 +68,7 @@ func (params) RawQuery() (s string) { return } -func (params) MarshalZerologObject(*zerolog.Event) { - return -} +func (params) MarshalZerologObject(*zerolog.Event) {} func init() { _ = log.ConfigureLogger("", "info", false, false) @@ -102,13 +98,15 @@ func TestHook_HandleAnnounceValid(t *testing.T) { defer s.Close() token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "CN=test", - Subject: "CN=test", - Audience: []string{"test"}, - ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, - NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, - ID: strconv.FormatInt(rand.Int63(), 16), + registeredClaimsWrapper: registeredClaimsWrapper{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, }, InfoHash: infoHash.String(), }) @@ -126,12 +124,12 @@ func TestHook_HandleAnnounceValid(t *testing.T) { } h, err := build(cfg, nil) require.Nil(t, err) - data := make(map[string]string) + data := make(params) data[authorizationHeader] = bearerAuthPrefix + tokenString _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ InfoHash: infoHash, RequestPeer: bittorrent.RequestPeer{}, - Params: ¶ms{data: data}, + Params: data, }, nil) require.Nil(t, err) } @@ -145,13 +143,15 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) { // now we wll use HMAC-SHA256 with invalid random key // all errors should be nil except announce request token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "CN=test", - Subject: "CN=test", - Audience: []string{"test"}, - ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, - NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, - ID: strconv.FormatInt(rand.Int63(), 16), + registeredClaimsWrapper: registeredClaimsWrapper{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, }, InfoHash: infoHash.String(), }) @@ -172,12 +172,12 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) { } h, err := build(cfg, nil) require.Nil(t, err) - data := make(map[string]string) + data := make(params) data["jwt"] = tokenString _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ InfoHash: infoHash, RequestPeer: bittorrent.RequestPeer{}, - Params: ¶ms{data: data}, + Params: data, }, nil) require.NotNil(t, err) } @@ -188,16 +188,27 @@ func TestHook_HandleScrapeValid(t *testing.T) { })) defer s.Close() + ihs := make(bittorrent.InfoHashes, rand.Intn(10)+1) + ihss := make([]string, len(ihs)) + for i := range ihs { + bb := []byte(infoHash) + bb[i] = byte(i) + ihs[i] = bittorrent.InfoHash(bb) + ihss[i] = ihs[i].String() + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "CN=test", - Subject: "CN=test", - Audience: []string{"test"}, - ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, - NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, - ID: strconv.FormatInt(rand.Int63(), 16), + registeredClaimsWrapper: registeredClaimsWrapper{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "CN=test", + Subject: "CN=test", + Audience: []string{"test"}, + ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, + NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, + ID: strconv.FormatInt(rand.Int63(), 16), + }, }, - InfoHashes: []string{infoHash.String()}, + InfoHashes: ihss, }) token.Header["kid"] = jwksData.Keys[0].KeyID @@ -213,12 +224,12 @@ func TestHook_HandleScrapeValid(t *testing.T) { } h, err := build(cfg, nil) require.Nil(t, err) - data := make(map[string]string) + data := make(params) data[authorizationHeader] = bearerAuthPrefix + tokenString _, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{ - InfoHashes: bittorrent.InfoHashes{infoHash}, + InfoHashes: ihs, RequestAddresses: bittorrent.RequestAddresses{}, - Params: ¶ms{data: data}, + Params: data, }, nil) require.Nil(t, err) }