mirror of
https://github.com/sot-tech/mochi.git
synced 2026-05-23 00:04:49 -07:00
(minor) refactor jwt
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
---
|
||||
run:
|
||||
# mochi in not written with generics (a.t.m),
|
||||
# so we can check with 1.17
|
||||
go: "1.17"
|
||||
timeout: "5m"
|
||||
output:
|
||||
sort-results: true
|
||||
|
||||
@@ -123,27 +123,38 @@ func (h *hook) Stop() stop.Result {
|
||||
return c.Result()
|
||||
}
|
||||
|
||||
type compatibleClaims interface {
|
||||
Valid() error
|
||||
ToRegisteredClaims() jwt.RegisteredClaims
|
||||
type verifiableClaims interface {
|
||||
jwt.Claims
|
||||
VerifyIssuer(iss string, req bool) bool
|
||||
GetIssuer() string
|
||||
VerifyAudience(aud string, req bool) bool
|
||||
GetAudience() []string
|
||||
}
|
||||
|
||||
type registeredClaimsWrapper struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (rc registeredClaimsWrapper) GetIssuer() string {
|
||||
return rc.Issuer
|
||||
}
|
||||
|
||||
func (rc registeredClaimsWrapper) GetAudience() []string {
|
||||
return rc.Audience
|
||||
}
|
||||
|
||||
type announceClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
registeredClaimsWrapper
|
||||
InfoHash string `json:"infohash,omitempty"`
|
||||
}
|
||||
|
||||
func (ac announceClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
||||
return ac.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
if !h.cfg.HandleAnnounce {
|
||||
return ctx, nil
|
||||
}
|
||||
var err error
|
||||
|
||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
||||
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
|
||||
err = ErrMissingJWT
|
||||
} else {
|
||||
claims := new(announceClaims)
|
||||
@@ -159,15 +170,15 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||
logger.Info().
|
||||
Err(err).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("InfoHash claim parse failed")
|
||||
Msg("'infohash' claim parse failed")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
if req.InfoHash != claimIH {
|
||||
logger.Info().
|
||||
Stringer("provided", claimIH).
|
||||
Stringer("required", req.InfoHash).
|
||||
Object("source", req.RequestPeer).
|
||||
Msg("InfoHash claim not equals to request InfoHash")
|
||||
Stringer("claimInfoHash", claimIH).
|
||||
Stringer("requestInfoHash", req.InfoHash).
|
||||
Object("peer", req.RequestPeer).
|
||||
Msg("unequal 'infohash' claim when validating JWT")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
}
|
||||
@@ -177,14 +188,10 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||
}
|
||||
|
||||
type scrapeClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
registeredClaimsWrapper
|
||||
InfoHashes []string `json:"infohashes,omitempty"`
|
||||
}
|
||||
|
||||
func (sc scrapeClaims) ToRegisteredClaims() jwt.RegisteredClaims {
|
||||
return sc.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||
if !h.cfg.HandleScrape {
|
||||
return ctx, nil
|
||||
@@ -192,7 +199,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
|
||||
var err error
|
||||
|
||||
if jwtParam := h.getJWT(req.Params); len(jwtParam) == 0 {
|
||||
if jwtParam := h.getJWTString(req.Params); len(jwtParam) == 0 {
|
||||
err = ErrMissingJWT
|
||||
} else {
|
||||
claims := new(scrapeClaims)
|
||||
@@ -210,8 +217,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
} else {
|
||||
logger.Info().
|
||||
Err(err).
|
||||
Array("source", req.RequestAddresses).
|
||||
Msg("InfoHash claim parse failed")
|
||||
Array("addresses", req.RequestAddresses).
|
||||
Msg("'infohashes' claim parse failed")
|
||||
}
|
||||
}
|
||||
eq := len(req.InfoHashes) == len(claimIHs)
|
||||
@@ -232,10 +239,10 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
}
|
||||
if !eq {
|
||||
logger.Info().
|
||||
Array("provided", claimIHs).
|
||||
Array("required", req.InfoHashes).
|
||||
Array("source", req.RequestAddresses).
|
||||
Msg("InfoHashes claim not equals to request InfoHashes")
|
||||
Array("claimInfoHashes", claimIHs).
|
||||
Array("requestInfoHashes", req.InfoHashes).
|
||||
Array("addresses", req.RequestAddresses).
|
||||
Msg("unequal 'infohashes' claim when validating JWT")
|
||||
err = ErrInvalidJWT
|
||||
}
|
||||
}
|
||||
@@ -244,7 +251,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
|
||||
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
||||
if params != nil {
|
||||
var found bool
|
||||
if jwt, found = params.String(h.cfg.Header); found {
|
||||
@@ -256,7 +263,7 @@ func (h *hook) getJWT(params bittorrent.Params) (jwt string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs []error) {
|
||||
func (h *hook) validateBaseJWT(jwtParam string, claims verifiableClaims) (errs []error) {
|
||||
if strings.HasPrefix(strings.ToLower(jwtParam), bearerAuthPrefix) {
|
||||
jwtParam = jwtParam[len(bearerAuthPrefix):]
|
||||
}
|
||||
@@ -266,25 +273,17 @@ func (h *hook) validateBaseJWT(jwtParam string, claims compatibleClaims) (errs [
|
||||
if err := claims.Valid(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if errs0 := h.validateRegisteredClaims(claims); len(errs0) > 0 {
|
||||
errs = append(errs, errs0...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *hook) validateRegisteredClaims(cl compatibleClaims) (errs []error) {
|
||||
rc := cl.ToRegisteredClaims()
|
||||
if !rc.VerifyIssuer(h.cfg.Issuer, true) {
|
||||
if !claims.VerifyIssuer(h.cfg.Issuer, true) {
|
||||
logger.Debug().
|
||||
Str("provided", rc.Issuer).
|
||||
Str("provided", claims.GetIssuer()).
|
||||
Str("required", h.cfg.Issuer).
|
||||
Msg("unequal or missing issuer when validating JWT")
|
||||
errs = append(errs, jwt.ErrTokenInvalidIssuer)
|
||||
}
|
||||
|
||||
if !rc.VerifyAudience(h.cfg.Audience, true) {
|
||||
if !claims.VerifyAudience(h.cfg.Audience, true) {
|
||||
logger.Debug().
|
||||
Strs("provided", rc.Audience).
|
||||
Strs("provided", claims.GetAudience()).
|
||||
Str("required", h.cfg.Audience).
|
||||
Msg("unequal or missing audience when validating JWT")
|
||||
errs = append(errs, jwt.ErrTokenInvalidAudience)
|
||||
|
||||
@@ -53,12 +53,10 @@ type JWKSKeys struct {
|
||||
Keys []JWKSKey `json:"keys"`
|
||||
}
|
||||
|
||||
type params struct {
|
||||
data map[string]string
|
||||
}
|
||||
type params map[string]string
|
||||
|
||||
func (p params) String(key string) (out string, found bool) {
|
||||
out, found = p.data[key]
|
||||
out, found = p[key]
|
||||
return
|
||||
}
|
||||
|
||||
@@ -70,9 +68,7 @@ func (params) RawQuery() (s string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (params) MarshalZerologObject(*zerolog.Event) {
|
||||
return
|
||||
}
|
||||
func (params) MarshalZerologObject(*zerolog.Event) {}
|
||||
|
||||
func init() {
|
||||
_ = log.ConfigureLogger("", "info", false, false)
|
||||
@@ -102,13 +98,15 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, announceClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
},
|
||||
},
|
||||
InfoHash: infoHash.String(),
|
||||
})
|
||||
@@ -126,12 +124,12 @@ func TestHook_HandleAnnounceValid(t *testing.T) {
|
||||
}
|
||||
h, err := build(cfg, nil)
|
||||
require.Nil(t, err)
|
||||
data := make(map[string]string)
|
||||
data := make(params)
|
||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
RequestPeer: bittorrent.RequestPeer{},
|
||||
Params: ¶ms{data: data},
|
||||
Params: data,
|
||||
}, nil)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
@@ -145,13 +143,15 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
||||
// now we wll use HMAC-SHA256 with invalid random key
|
||||
// all errors should be nil except announce request
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, announceClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
},
|
||||
},
|
||||
InfoHash: infoHash.String(),
|
||||
})
|
||||
@@ -172,12 +172,12 @@ func TestHook_HandleAnnounceInvalid(t *testing.T) {
|
||||
}
|
||||
h, err := build(cfg, nil)
|
||||
require.Nil(t, err)
|
||||
data := make(map[string]string)
|
||||
data := make(params)
|
||||
data["jwt"] = tokenString
|
||||
_, err = h.HandleAnnounce(context.Background(), &bittorrent.AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
RequestPeer: bittorrent.RequestPeer{},
|
||||
Params: ¶ms{data: data},
|
||||
Params: data,
|
||||
}, nil)
|
||||
require.NotNil(t, err)
|
||||
}
|
||||
@@ -188,16 +188,27 @@ func TestHook_HandleScrapeValid(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
ihs := make(bittorrent.InfoHashes, rand.Intn(10)+1)
|
||||
ihss := make([]string, len(ihs))
|
||||
for i := range ihs {
|
||||
bb := []byte(infoHash)
|
||||
bb[i] = byte(i)
|
||||
ihs[i] = bittorrent.InfoHash(bb)
|
||||
ihss[i] = ihs[i].String()
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, scrapeClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
registeredClaimsWrapper: registeredClaimsWrapper{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "CN=test",
|
||||
Subject: "CN=test",
|
||||
Audience: []string{"test"},
|
||||
ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(time.Hour)},
|
||||
NotBefore: &jwt.NumericDate{Time: time.Now().Add(-time.Hour)},
|
||||
ID: strconv.FormatInt(rand.Int63(), 16),
|
||||
},
|
||||
},
|
||||
InfoHashes: []string{infoHash.String()},
|
||||
InfoHashes: ihss,
|
||||
})
|
||||
|
||||
token.Header["kid"] = jwksData.Keys[0].KeyID
|
||||
@@ -213,12 +224,12 @@ func TestHook_HandleScrapeValid(t *testing.T) {
|
||||
}
|
||||
h, err := build(cfg, nil)
|
||||
require.Nil(t, err)
|
||||
data := make(map[string]string)
|
||||
data := make(params)
|
||||
data[authorizationHeader] = bearerAuthPrefix + tokenString
|
||||
_, err = h.HandleScrape(context.Background(), &bittorrent.ScrapeRequest{
|
||||
InfoHashes: bittorrent.InfoHashes{infoHash},
|
||||
InfoHashes: ihs,
|
||||
RequestAddresses: bittorrent.RequestAddresses{},
|
||||
Params: ¶ms{data: data},
|
||||
Params: data,
|
||||
}, nil)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user