(minor) refactor jwt

This commit is contained in:
Lawrence, Rendall
2022-08-26 11:13:18 +03:00
parent 498779aeaf
commit d7777ca583
3 changed files with 90 additions and 77 deletions

View File

@@ -1,5 +1,8 @@
--- ---
run: run:
# mochi in not written with generics (a.t.m),
# so we can check with 1.17
go: "1.17"
timeout: "5m" timeout: "5m"
output: output:
sort-results: true sort-results: true
@@ -51,4 +54,4 @@ issues:
- "EXC0012" # Exported should have comment - "EXC0012" # Exported should have comment
- "EXC0013" # Package comment should be of form - "EXC0013" # Package comment should be of form
- "EXC0014" # Comment on exported should be of form - "EXC0014" # Comment on exported should be of form
- "EXC0015" # Should have a package comment - "EXC0015" # Should have a package comment

View File

@@ -123,27 +123,38 @@ func (h *hook) Stop() stop.Result {
return c.Result() return c.Result()
} }
type compatibleClaims interface { type verifiableClaims interface {
Valid() error jwt.Claims
ToRegisteredClaims() jwt.RegisteredClaims VerifyIssuer(iss string, req bool) bool
GetIssuer() string
VerifyAudience(aud string, req bool) bool
GetAudience() []string
}
type registeredClaimsWrapper struct {
jwt.RegisteredClaims
}
func (rc registeredClaimsWrapper) GetIssuer() string {
return rc.Issuer
}
func (rc registeredClaimsWrapper) GetAudience() []string {
return rc.Audience
} }
type announceClaims struct { type announceClaims struct {
jwt.RegisteredClaims registeredClaimsWrapper
InfoHash string `json:"infohash,omitempty"` InfoHash string `json:"infohash,omitempty"`
} }
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return ac.RegisteredClaims
}
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
if !h.cfg.HandleAnnounce { if !h.cfg.HandleAnnounce {
return ctx, nil return ctx, nil
} }
var err error var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT err = ErrMissingJWT
} else { } else {
claims := new(announceClaims) claims := new(announceClaims)
@@ -159,15 +170,15 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
logger.Info(). logger.Info().
Err(err). Err(err).
Object("source", req.RequestPeer). Object("source", req.RequestPeer).
Msg("InfoHash claim parse failed") Msg("'infohash' claim parse failed")
err = ErrInvalidJWT err = ErrInvalidJWT
} }
if req.InfoHash != claimIH { if req.InfoHash != claimIH {
logger.Info(). logger.Info().
Stringer("provided", claimIH). Stringer("claimInfoHash", claimIH).
Stringer("required", req.InfoHash). Stringer("requestInfoHash", req.InfoHash).
Object("source", req.RequestPeer). Object("peer", req.RequestPeer).
Msg("InfoHash claim not equals to request InfoHash") Msg("unequal 'infohash' claim when validating JWT")
err = ErrInvalidJWT err = ErrInvalidJWT
} }
} }
@@ -177,14 +188,10 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
} }
type scrapeClaims struct { type scrapeClaims struct {
jwt.RegisteredClaims registeredClaimsWrapper
InfoHashes []string `json:"infohashes,omitempty"` InfoHashes []string `json:"infohashes,omitempty"`
} }
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return sc.RegisteredClaims
}
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
if !h.cfg.HandleScrape { if !h.cfg.HandleScrape {
return ctx, nil return ctx, nil
@@ -192,7 +199,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
var err error var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 { if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT err = ErrMissingJWT
} else { } else {
claims := new(scrapeClaims) claims := new(scrapeClaims)
@@ -210,8 +217,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
} else { } else {
logger.Info(). logger.Info().
Err(err). Err(err).
Array("source", req.RequestAddresses). Array("addresses", req.RequestAddresses).
Msg("InfoHash claim parse failed") Msg("'infohashes' claim parse failed")
} }
} }
eq := len(req.InfoHashes) == len(claimIHs) eq := len(req.InfoHashes) == len(claimIHs)
@@ -232,10 +239,10 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
} }
if !eq { if !eq {
logger.Info(). logger.Info().
Array("provided", claimIHs). Array("claimInfoHashes", claimIHs).
Array("required", req.InfoHashes). Array("requestInfoHashes", req.InfoHashes).
Array("source", req.RequestAddresses). Array("addresses", req.RequestAddresses).
Msg("InfoHashes claim not equals to request InfoHashes") Msg("unequal 'infohashes' claim when validating JWT")
err = ErrInvalidJWT err = ErrInvalidJWT
} }
} }
@@ -244,7 +251,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
return ctx, err return ctx, err
} }
func (h *hook) getJWT(params bittorrent.Params) (jwt string) { func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
if params != nil { if params != nil {
var found bool var found bool
if jwt, found = params.String(h.cfg.Header); found { if jwt, found = params.String(h.cfg.Header); found {
@@ -256,7 +263,7 @@ func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
return return
} }
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) { func (h *hook) validateBaseJWT(jwtParam string, claims verifiableClaims) (errs []error) {
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) { if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
jwtParam = jwtParam[len(bearerAuthPrefix):] jwtParam = jwtParam[len(bearerAuthPrefix):]
} }
@@ -266,25 +273,17 @@ func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs [
if err := claims.Valid(); err != nil { if err := claims.Valid(); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
errs = append(errs, errs0...)
}
return
}
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) { if !claims.VerifyIssuer(h.cfg.Issuer, true) {
rc := cl.ToRegisteredClaims()
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
logger.Debug(). logger.Debug().
Str("provided", rc.Issuer). Str("provided", claims.GetIssuer()).
Str("required", h.cfg.Issuer). Str("required", h.cfg.Issuer).
Msg("unequal or missing issuer when validating JWT") Msg("unequal or missing issuer when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidIssuer) errs = append(errs, jwt.ErrTokenInvalidIssuer)
} }
if !claims.VerifyAudience(h.cfg.Audience, true) {
if !rc.VerifyAudience(h.cfg.Audience, true) {
logger.Debug(). logger.Debug().
Strs("provided", rc.Audience). Strs("provided", claims.GetAudience()).
Str("required", h.cfg.Audience). Str("required", h.cfg.Audience).
Msg("unequal or missing audience when validating JWT") Msg("unequal or missing audience when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidAudience) errs = append(errs, jwt.ErrTokenInvalidAudience)

View File

@@ -53,12 +53,10 @@ type JWKSKeys struct {
Keys []JWKSKey `json:"keys"` Keys []JWKSKey `json:"keys"`
} }
type params struct { type params map[string]string
data map[string]string
}
func (p params) String(key string) (out string, found bool) { func (p params) String(key string) (out string, found bool) {
out, found = p.data[key] out, found = p[key]
return return
} }
@@ -70,9 +68,7 @@ func (params) RawQuery() (s string) {
return return
} }
func (params) MarshalZerologObject(*zerolog.Event) { func (params) MarshalZerologObject(*zerolog.Event) {}
return
}
func init() { func init() {
_ = log.ConfigureLogger("", "info", false, false) _ = log.ConfigureLogger("", "info", false, false)
@@ -102,13 +98,15 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
defer s.Close() defer s.Close()
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{ token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{ registeredClaimsWrapper: registeredClaimsWrapper{
Issuer: "CN=test", RegisteredClaims: jwt.RegisteredClaims{
Subject: "CN=test", Issuer: "CN=test",
Audience: []string{"test"}, Subject: "CN=test",
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, Audience: []string{"test"},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16), NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
}, },
InfoHash: infoHash.String(), InfoHash: infoHash.String(),
}) })
@@ -126,12 +124,12 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
} }
h, err := build(cfg, nil) h, err := build(cfg, nil)
require.Nil(t, err) require.Nil(t, err)
data := make(map[string]string) data := make(params)
data[authorizationHeader] = bearerAuthPrefix + tokenString data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
InfoHash: infoHash, InfoHash: infoHash,
RequestPeer: bittorrent.RequestPeer{}, RequestPeer: bittorrent.RequestPeer{},
Params: &params{data: data}, Params: data,
}, nil) }, nil)
require.Nil(t, err) require.Nil(t, err)
} }
@@ -145,13 +143,15 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
// now we wll use HMAC-SHA256 with invalid random key // now we wll use HMAC-SHA256 with invalid random key
// all errors should be nil except announce request // all errors should be nil except announce request
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{ registeredClaimsWrapper: registeredClaimsWrapper{
Issuer: "CN=test", RegisteredClaims: jwt.RegisteredClaims{
Subject: "CN=test", Issuer: "CN=test",
Audience: []string{"test"}, Subject: "CN=test",
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, Audience: []string{"test"},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16), NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
}, },
InfoHash: infoHash.String(), InfoHash: infoHash.String(),
}) })
@@ -172,12 +172,12 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
} }
h, err := build(cfg, nil) h, err := build(cfg, nil)
require.Nil(t, err) require.Nil(t, err)
data := make(map[string]string) data := make(params)
data["jwt"] = tokenString data["jwt"] = tokenString
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{ _, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
InfoHash: infoHash, InfoHash: infoHash,
RequestPeer: bittorrent.RequestPeer{}, RequestPeer: bittorrent.RequestPeer{},
Params: &params{data: data}, Params: data,
}, nil) }, nil)
require.NotNil(t, err) require.NotNil(t, err)
} }
@@ -188,16 +188,27 @@ func TestHook_HandleScrapeValid(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
ihs := make(bittorrent.InfoHashes, rand.Intn(10)+1)
ihss := make([]string, len(ihs))
for i := range ihs {
bb := []byte(infoHash)
bb[i] = byte(i)
ihs[i] = bittorrent.InfoHash(bb)
ihss[i] = ihs[i].String()
}
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{ token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
RegisteredClaims: jwt.RegisteredClaims{ registeredClaimsWrapper: registeredClaimsWrapper{
Issuer: "CN=test", RegisteredClaims: jwt.RegisteredClaims{
Subject: "CN=test", Issuer: "CN=test",
Audience: []string{"test"}, Subject: "CN=test",
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)}, Audience: []string{"test"},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)}, ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16), NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
}, },
InfoHashes: []string{infoHash.String()}, InfoHashes: ihss,
}) })
token.Header["kid"] = jwksData.Keys[0].KeyID token.Header["kid"] = jwksData.Keys[0].KeyID
@@ -213,12 +224,12 @@ func TestHook_HandleScrapeValid(t *testing.T) {
} }
h, err := build(cfg, nil) h, err := build(cfg, nil)
require.Nil(t, err) require.Nil(t, err)
data := make(map[string]string) data := make(params)
data[authorizationHeader] = bearerAuthPrefix + tokenString data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{ _, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
InfoHashes: bittorrent.InfoHashes{infoHash}, InfoHashes: ihs,
RequestAddresses: bittorrent.RequestAddresses{}, RequestAddresses: bittorrent.RequestAddresses{},
Params: &params{data: data}, Params: data,
}, nil) }, nil)
require.Nil(t, err) require.Nil(t, err)
} }