mirror of
https://github.com/sot-tech/mochi.git
synced 2026-05-23 08:14:48 -07:00
(minor) refactor jwt
This commit is contained in:
@@ -1,5 +1,8 @@
|
|||||||
---
|
---
|
||||||
run:
|
run:
|
||||||
|
# mochi in not written with generics (a.t.m),
|
||||||
|
# so we can check with 1.17
|
||||||
|
go: "1.17"
|
||||||
timeout: "5m"
|
timeout: "5m"
|
||||||
output:
|
output:
|
||||||
sort-results: true
|
sort-results: true
|
||||||
@@ -51,4 +54,4 @@ issues:
|
|||||||
- "EXC0012" # Exported should have comment
|
- "EXC0012" # Exported should have comment
|
||||||
- "EXC0013" # Package comment should be of form
|
- "EXC0013" # Package comment should be of form
|
||||||
- "EXC0014" # Comment on exported should be of form
|
- "EXC0014" # Comment on exported should be of form
|
||||||
- "EXC0015" # Should have a package comment
|
- "EXC0015" # Should have a package comment
|
||||||
|
|||||||
@@ -123,27 +123,38 @@ func (h *hook) Stop() stop.Result {
|
|||||||
return c.Result()
|
return c.Result()
|
||||||
}
|
}
|
||||||
|
|
||||||
type compatibleClaims interface {
|
type verifiableClaims interface {
|
||||||
Valid() error
|
jwt.Claims
|
||||||
ToRegisteredClaims() jwt.RegisteredClaims
|
VerifyIssuer(iss string, req bool) bool
|
||||||
|
GetIssuer() string
|
||||||
|
VerifyAudience(aud string, req bool) bool
|
||||||
|
GetAudience() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type registeredClaimsWrapper struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc registeredClaimsWrapper) GetIssuer() string {
|
||||||
|
return rc.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc registeredClaimsWrapper) GetAudience() []string {
|
||||||
|
return rc.Audience
|
||||||
}
|
}
|
||||||
|
|
||||||
type announceClaims struct {
|
type announceClaims struct {
|
||||||
jwt.RegisteredClaims
|
registeredClaimsWrapper
|
||||||
InfoHash string `json:"infohash,omitempty"`
|
InfoHash string `json:"infohash,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
|
||||||
return ac.RegisteredClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||||
if !h.cfg.HandleAnnounce {
|
if !h.cfg.HandleAnnounce {
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
|
||||||
err = ErrMissingJWT
|
err = ErrMissingJWT
|
||||||
} else {
|
} else {
|
||||||
claims := new(announceClaims)
|
claims := new(announceClaims)
|
||||||
@@ -159,15 +170,15 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
|||||||
logger.Info().
|
logger.Info().
|
||||||
Err(err).
|
Err(err).
|
||||||
Object("source", req.RequestPeer).
|
Object("source", req.RequestPeer).
|
||||||
Msg("InfoHash claim parse failed")
|
Msg("'infohash' claim parse failed")
|
||||||
err = ErrInvalidJWT
|
err = ErrInvalidJWT
|
||||||
}
|
}
|
||||||
if req.InfoHash != claimIH {
|
if req.InfoHash != claimIH {
|
||||||
logger.Info().
|
logger.Info().
|
||||||
Stringer("provided", claimIH).
|
Stringer("claimInfoHash", claimIH).
|
||||||
Stringer("required", req.InfoHash).
|
Stringer("requestInfoHash", req.InfoHash).
|
||||||
Object("source", req.RequestPeer).
|
Object("peer", req.RequestPeer).
|
||||||
Msg("InfoHash claim not equals to request InfoHash")
|
Msg("unequal 'infohash' claim when validating JWT")
|
||||||
err = ErrInvalidJWT
|
err = ErrInvalidJWT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -177,14 +188,10 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
type scrapeClaims struct {
|
type scrapeClaims struct {
|
||||||
jwt.RegisteredClaims
|
registeredClaimsWrapper
|
||||||
InfoHashes []string `json:"infohashes,omitempty"`
|
InfoHashes []string `json:"infohashes,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
|
||||||
return sc.RegisteredClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||||
if !h.cfg.HandleScrape {
|
if !h.cfg.HandleScrape {
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
@@ -192,7 +199,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
|
||||||
err = ErrMissingJWT
|
err = ErrMissingJWT
|
||||||
} else {
|
} else {
|
||||||
claims := new(scrapeClaims)
|
claims := new(scrapeClaims)
|
||||||
@@ -210,8 +217,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
|||||||
} else {
|
} else {
|
||||||
logger.Info().
|
logger.Info().
|
||||||
Err(err).
|
Err(err).
|
||||||
Array("source", req.RequestAddresses).
|
Array("addresses", req.RequestAddresses).
|
||||||
Msg("InfoHash claim parse failed")
|
Msg("'infohashes' claim parse failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
eq := len(req.InfoHashes) == len(claimIHs)
|
eq := len(req.InfoHashes) == len(claimIHs)
|
||||||
@@ -232,10 +239,10 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
|||||||
}
|
}
|
||||||
if !eq {
|
if !eq {
|
||||||
logger.Info().
|
logger.Info().
|
||||||
Array("provided", claimIHs).
|
Array("claimInfoHashes", claimIHs).
|
||||||
Array("required", req.InfoHashes).
|
Array("requestInfoHashes", req.InfoHashes).
|
||||||
Array("source", req.RequestAddresses).
|
Array("addresses", req.RequestAddresses).
|
||||||
Msg("InfoHashes claim not equals to request InfoHashes")
|
Msg("unequal 'infohashes' claim when validating JWT")
|
||||||
err = ErrInvalidJWT
|
err = ErrInvalidJWT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -244,7 +251,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
|||||||
return ctx, err
|
return ctx, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
|
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
||||||
if params != nil {
|
if params != nil {
|
||||||
var found bool
|
var found bool
|
||||||
if jwt, found = params.String(h.cfg.Header); found {
|
if jwt, found = params.String(h.cfg.Header); found {
|
||||||
@@ -256,7 +263,7 @@ func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) {
|
func (h *hook) validateBaseJWT(jwtParam string, claims verifiableClaims) (errs []error) {
|
||||||
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
|
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
|
||||||
jwtParam = jwtParam[len(bearerAuthPrefix):]
|
jwtParam = jwtParam[len(bearerAuthPrefix):]
|
||||||
}
|
}
|
||||||
@@ -266,25 +273,17 @@ func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs [
|
|||||||
if err := claims.Valid(); err != nil {
|
if err := claims.Valid(); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
|
|
||||||
errs = append(errs, errs0...)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) {
|
if !claims.VerifyIssuer(h.cfg.Issuer, true) {
|
||||||
rc := cl.ToRegisteredClaims()
|
|
||||||
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
|
|
||||||
logger.Debug().
|
logger.Debug().
|
||||||
Str("provided", rc.Issuer).
|
Str("provided", claims.GetIssuer()).
|
||||||
Str("required", h.cfg.Issuer).
|
Str("required", h.cfg.Issuer).
|
||||||
Msg("unequal or missing issuer when validating JWT")
|
Msg("unequal or missing issuer when validating JWT")
|
||||||
errs = append(errs, jwt.ErrTokenInvalidIssuer)
|
errs = append(errs, jwt.ErrTokenInvalidIssuer)
|
||||||
}
|
}
|
||||||
|
if !claims.VerifyAudience(h.cfg.Audience, true) {
|
||||||
if !rc.VerifyAudience(h.cfg.Audience, true) {
|
|
||||||
logger.Debug().
|
logger.Debug().
|
||||||
Strs("provided", rc.Audience).
|
Strs("provided", claims.GetAudience()).
|
||||||
Str("required", h.cfg.Audience).
|
Str("required", h.cfg.Audience).
|
||||||
Msg("unequal or missing audience when validating JWT")
|
Msg("unequal or missing audience when validating JWT")
|
||||||
errs = append(errs, jwt.ErrTokenInvalidAudience)
|
errs = append(errs, jwt.ErrTokenInvalidAudience)
|
||||||
|
|||||||
@@ -53,12 +53,10 @@ type JWKSKeys struct {
|
|||||||
Keys []JWKSKey `json:"keys"`
|
Keys []JWKSKey `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type params struct {
|
type params map[string]string
|
||||||
data map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p params) String(key string) (out string, found bool) {
|
func (p params) String(key string) (out string, found bool) {
|
||||||
out, found = p.data[key]
|
out, found = p[key]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,9 +68,7 @@ func (params) RawQuery() (s string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (params) MarshalZerologObject(*zerolog.Event) {
|
func (params) MarshalZerologObject(*zerolog.Event) {}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
_ = log.ConfigureLogger("", "info", false, false)
|
_ = log.ConfigureLogger("", "info", false, false)
|
||||||
@@ -102,13 +98,15 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
|||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||||
Issuer: "CN=test",
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Subject: "CN=test",
|
Issuer: "CN=test",
|
||||||
Audience: []string{"test"},
|
Subject: "CN=test",
|
||||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
Audience: []string{"test"},
|
||||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||||
|
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
InfoHash: infoHash.String(),
|
InfoHash: infoHash.String(),
|
||||||
})
|
})
|
||||||
@@ -126,12 +124,12 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
h, err := build(cfg, nil)
|
h, err := build(cfg, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
data := make(map[string]string)
|
data := make(params)
|
||||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||||
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
||||||
InfoHash: infoHash,
|
InfoHash: infoHash,
|
||||||
RequestPeer: bittorrent.RequestPeer{},
|
RequestPeer: bittorrent.RequestPeer{},
|
||||||
Params: ¶ms{data: data},
|
Params: data,
|
||||||
}, nil)
|
}, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
@@ -145,13 +143,15 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
|||||||
// now we wll use HMAC-SHA256 with invalid random key
|
// now we wll use HMAC-SHA256 with invalid random key
|
||||||
// all errors should be nil except announce request
|
// all errors should be nil except announce request
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||||
Issuer: "CN=test",
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Subject: "CN=test",
|
Issuer: "CN=test",
|
||||||
Audience: []string{"test"},
|
Subject: "CN=test",
|
||||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
Audience: []string{"test"},
|
||||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||||
|
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
InfoHash: infoHash.String(),
|
InfoHash: infoHash.String(),
|
||||||
})
|
})
|
||||||
@@ -172,12 +172,12 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
h, err := build(cfg, nil)
|
h, err := build(cfg, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
data := make(map[string]string)
|
data := make(params)
|
||||||
data["jwt"] = tokenString
|
data["jwt"] = tokenString
|
||||||
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
||||||
InfoHash: infoHash,
|
InfoHash: infoHash,
|
||||||
RequestPeer: bittorrent.RequestPeer{},
|
RequestPeer: bittorrent.RequestPeer{},
|
||||||
Params: ¶ms{data: data},
|
Params: data,
|
||||||
}, nil)
|
}, nil)
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
}
|
}
|
||||||
@@ -188,16 +188,27 @@ func TestHook_HandleScrapeValid(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
ihs := make(bittorrent.InfoHashes, rand.Intn(10)+1)
|
||||||
|
ihss := make([]string, len(ihs))
|
||||||
|
for i := range ihs {
|
||||||
|
bb := []byte(infoHash)
|
||||||
|
bb[i] = byte(i)
|
||||||
|
ihs[i] = bittorrent.InfoHash(bb)
|
||||||
|
ihss[i] = ihs[i].String()
|
||||||
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||||
Issuer: "CN=test",
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Subject: "CN=test",
|
Issuer: "CN=test",
|
||||||
Audience: []string{"test"},
|
Subject: "CN=test",
|
||||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
Audience: []string{"test"},
|
||||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||||
|
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
InfoHashes: []string{infoHash.String()},
|
InfoHashes: ihss,
|
||||||
})
|
})
|
||||||
|
|
||||||
token.Header["kid"] = jwksData.Keys[0].KeyID
|
token.Header["kid"] = jwksData.Keys[0].KeyID
|
||||||
@@ -213,12 +224,12 @@ func TestHook_HandleScrapeValid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
h, err := build(cfg, nil)
|
h, err := build(cfg, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
data := make(map[string]string)
|
data := make(params)
|
||||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||||
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
|
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
|
||||||
InfoHashes: bittorrent.InfoHashes{infoHash},
|
InfoHashes: ihs,
|
||||||
RequestAddresses: bittorrent.RequestAddresses{},
|
RequestAddresses: bittorrent.RequestAddresses{},
|
||||||
Params: ¶ms{data: data},
|
Params: data,
|
||||||
}, nil)
|
}, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user