(minor) refactor jwt

This commit is contained in:
Lawrence, Rendall
2022-08-26 11:13:18 +03:00
parent 498779aeaf
commit d7777ca583
3 changed files with 90 additions and 77 deletions

View File

@@ -1,5 +1,8 @@
---
run:
# mochi in not written with generics (a.t.m),
# so we can check with 1.17
go: "1.17"
timeout: "5m"
output:
sort-results: true
@@ -51,4 +54,4 @@ issues:
- "EXC0012" # Exported should have comment
- "EXC0013" # Package comment should be of form
- "EXC0014" # Comment on exported should be of form
- "EXC0015" # Should have a package comment
- "EXC0015" # Should have a package comment

View File

@@ -123,27 +123,38 @@ func (h *hook) Stop() stop.Result {
return c.Result()
}
type compatibleClaims interface {
Valid() error
ToRegisteredClaims() jwt.RegisteredClaims
type verifiableClaims interface {
jwt.Claims
VerifyIssuer(iss string, req bool) bool
GetIssuer() string
VerifyAudience(aud string, req bool) bool
GetAudience() []string
}
type registeredClaimsWrapper struct {
jwt.RegisteredClaims
}
func (rc registeredClaimsWrapper) GetIssuer() string {
return rc.Issuer
}
func (rc registeredClaimsWrapper) GetAudience() []string {
return rc.Audience
}
type announceClaims struct {
jwt.RegisteredClaims
registeredClaimsWrapper
InfoHash string `json:"infohash,omitempty"`
}
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return ac.RegisteredClaims
}
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
if !h.cfg.HandleAnnounce {
return ctx, nil
}
var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT
} else {
claims := new(announceClaims)
@@ -159,15 +170,15 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
logger.Info().
Err(err).
Object("source", req.RequestPeer).
Msg("InfoHash claim parse failed")
Msg("'infohash' claim parse failed")
err = ErrInvalidJWT
}
if req.InfoHash != claimIH {
logger.Info().
Stringer("provided", claimIH).
Stringer("required", req.InfoHash).
Object("source", req.RequestPeer).
Msg("InfoHash claim not equals to request InfoHash")
Stringer("claimInfoHash", claimIH).
Stringer("requestInfoHash", req.InfoHash).
Object("peer", req.RequestPeer).
Msg("unequal 'infohash' claim when validating JWT")
err = ErrInvalidJWT
}
}
@@ -177,14 +188,10 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
}
type scrapeClaims struct {
jwt.RegisteredClaims
registeredClaimsWrapper
InfoHashes []string `json:"infohashes,omitempty"`
}
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
return sc.RegisteredClaims
}
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
if !h.cfg.HandleScrape {
return ctx, nil
@@ -192,7 +199,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
var err error
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
err = ErrMissingJWT
} else {
claims := new(scrapeClaims)
@@ -210,8 +217,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
} else {
logger.Info().
Err(err).
Array("source", req.RequestAddresses).
Msg("InfoHash claim parse failed")
Array("addresses", req.RequestAddresses).
Msg("'infohashes' claim parse failed")
}
}
eq := len(req.InfoHashes) == len(claimIHs)
@@ -232,10 +239,10 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
}
if !eq {
logger.Info().
Array("provided", claimIHs).
Array("required", req.InfoHashes).
Array("source", req.RequestAddresses).
Msg("InfoHashes claim not equals to request InfoHashes")
Array("claimInfoHashes", claimIHs).
Array("requestInfoHashes", req.InfoHashes).
Array("addresses", req.RequestAddresses).
Msg("unequal 'infohashes' claim when validating JWT")
err = ErrInvalidJWT
}
}
@@ -244,7 +251,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
return ctx, err
}
func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
if params != nil {
var found bool
if jwt, found = params.String(h.cfg.Header); found {
@@ -256,7 +263,7 @@ func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
return
}
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) {
func (h *hook) validateBaseJWT(jwtParam string, claims verifiableClaims) (errs []error) {
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
jwtParam = jwtParam[len(bearerAuthPrefix):]
}
@@ -266,25 +273,17 @@ func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs [
if err := claims.Valid(); err != nil {
errs = append(errs, err)
}
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
errs = append(errs, errs0...)
}
return
}
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) {
rc := cl.ToRegisteredClaims()
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
if !claims.VerifyIssuer(h.cfg.Issuer, true) {
logger.Debug().
Str("provided", rc.Issuer).
Str("provided", claims.GetIssuer()).
Str("required", h.cfg.Issuer).
Msg("unequal or missing issuer when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidIssuer)
}
if !rc.VerifyAudience(h.cfg.Audience, true) {
if !claims.VerifyAudience(h.cfg.Audience, true) {
logger.Debug().
Strs("provided", rc.Audience).
Strs("provided", claims.GetAudience()).
Str("required", h.cfg.Audience).
Msg("unequal or missing audience when validating JWT")
errs = append(errs, jwt.ErrTokenInvalidAudience)

View File

@@ -53,12 +53,10 @@ type JWKSKeys struct {
Keys []JWKSKey `json:"keys"`
}
type params struct {
data map[string]string
}
type params map[string]string
func (p params) String(key string) (out string, found bool) {
out, found = p.data[key]
out, found = p[key]
return
}
@@ -70,9 +68,7 @@ func (params) RawQuery() (s string) {
return
}
func (params) MarshalZerologObject(*zerolog.Event) {
return
}
func (params) MarshalZerologObject(*zerolog.Event) {}
func init() {
_ = log.ConfigureLogger("", "info", false, false)
@@ -102,13 +98,15 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
defer s.Close()
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
registeredClaimsWrapper: registeredClaimsWrapper{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
},
InfoHash: infoHash.String(),
})
@@ -126,12 +124,12 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
}
h, err := build(cfg, nil)
require.Nil(t, err)
data := make(map[string]string)
data := make(params)
data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
InfoHash: infoHash,
RequestPeer: bittorrent.RequestPeer{},
Params: &params{data: data},
Params: data,
}, nil)
require.Nil(t, err)
}
@@ -145,13 +143,15 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
// now we wll use HMAC-SHA256 with invalid random key
// all errors should be nil except announce request
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
registeredClaimsWrapper: registeredClaimsWrapper{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
},
InfoHash: infoHash.String(),
})
@@ -172,12 +172,12 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
}
h, err := build(cfg, nil)
require.Nil(t, err)
data := make(map[string]string)
data := make(params)
data["jwt"] = tokenString
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
InfoHash: infoHash,
RequestPeer: bittorrent.RequestPeer{},
Params: &params{data: data},
Params: data,
}, nil)
require.NotNil(t, err)
}
@@ -188,16 +188,27 @@ func TestHook_HandleScrapeValid(t *testing.T) {
}))
defer s.Close()
ihs := make(bittorrent.InfoHashes, rand.Intn(10)+1)
ihss := make([]string, len(ihs))
for i := range ihs {
bb := []byte(infoHash)
bb[i] = byte(i)
ihs[i] = bittorrent.InfoHash(bb)
ihss[i] = ihs[i].String()
}
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
registeredClaimsWrapper: registeredClaimsWrapper{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "CN=test",
Subject: "CN=test",
Audience: []string{"test"},
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
ID: strconv.FormatInt(rand.Int63(), 16),
},
},
InfoHashes: []string{infoHash.String()},
InfoHashes: ihss,
})
token.Header["kid"] = jwksData.Keys[0].KeyID
@@ -213,12 +224,12 @@ func TestHook_HandleScrapeValid(t *testing.T) {
}
h, err := build(cfg, nil)
require.Nil(t, err)
data := make(map[string]string)
data := make(params)
data[authorizationHeader] = bearerAuthPrefix + tokenString
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
InfoHashes: bittorrent.InfoHashes{infoHash},
InfoHashes: ihs,
RequestAddresses: bittorrent.RequestAddresses{},
Params: &params{data: data},
Params: data,
}, nil)
require.Nil(t, err)
}