(partially tested) simplify client approval m/w

* sanitize code
This commit is contained in:
Lawrence, Rendall
2023-03-23 00:34:10 +03:00
parent 967611620e
commit 61f859e3f6
18 changed files with 99 additions and 128 deletions

View File

@@ -13,9 +13,9 @@ import (
//
// See ParseURLData for specifics on parsing and limitations.
type Params interface {
// String returns a string parsed from a query. Every key can be
// GetString returns a string parsed from a query. Every key can be
// returned as a string because they are encoded in the URL as strings.
String(key string) (string, bool)
GetString(key string) (string, bool)
zerolog.LogObjectMarshaler
}

View File

@@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) {
// Scrapes wrapper of array of Scrape-s
type Scrapes []Scrape
func (s *Scrapes) Len() int {
return len(*s)
}
func (s *Scrapes) Less(i, j int) bool {
return (*s)[i].InfoHash < (*s)[j].InfoHash
}
func (s *Scrapes) Swap(i, j int) {
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
}
// MarshalZerologArray writes array elements to zerolog event
func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
if s != nil {
for _, scrape := range *s {
a.Object(scrape)
}
func (s Scrapes) MarshalZerologArray(a *zerolog.Array) {
for _, scrape := range s {
a.Object(scrape)
}
}
@@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
// The Scrapes must be in the same order as the InfoHashes in the corresponding
// ScrapeRequest.
type ScrapeResponse struct {
Files Scrapes
Data Scrapes
}
// MarshalZerologObject writes fields into zerolog event
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
e.Array("scrapes", &sr.Files)
e.Array("scrapes", sr.Data)
}

View File

@@ -190,10 +190,10 @@ prehooks:
#
# - name: client approval
# config:
# whitelist:
# client_id_list:
# - "OP1011"
# blacklist:
# - "OP1012"
# true - whitelist mode, false - blacklist
# invert: true
#
# - name: interval variation
# config:

View File

@@ -3,6 +3,7 @@ package http
import (
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/str2bytes"
"github.com/valyala/fasthttp"
)
@@ -12,11 +13,11 @@ type queryParams struct {
*fasthttp.Args
}
// String returns a string parsed from a query. Every key can be returned as a
// GetString returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings.
func (qp queryParams) String(key string) (string, bool) {
func (qp queryParams) GetString(key string) (string, bool) {
v := qp.Peek(key)
return string(v), v != nil
return str2bytes.BytesToString(v), v != nil
}
// InfoHashes returns a list of requested infohashes.

View File

@@ -5,10 +5,11 @@ import (
"errors"
"net/netip"
"github.com/valyala/fasthttp"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/frontend"
"github.com/valyala/fasthttp"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// ParseOptions is the configuration used to parse an Announce Request.
@@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
// Attempt to parse the event from the request.
var eventStr string
var err error
eventStr, request.EventProvided = qp.String("event")
eventStr, request.EventProvided = qp.GetString("event")
if request.EventProvided {
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
return nil, err
@@ -136,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
if opts.AllowIPSpoofing {
for _, f := range []string{"ip", "ipv4", "ipv6"} {
if ipStr, ok := p.String(f); ok {
if ipStr, ok := p.GetString(f); ok {
addresses.Add(parseRequestAddress(ipStr, true))
}
}
@@ -146,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad
for _, ipStr := range ipValues {
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
addresses.Add(parseRequestAddress(string(ipStr), false))
addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false))
}
}
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"sort"
"strconv"
"time"
@@ -17,7 +18,7 @@ import (
var respBufferPool = bytepool.NewBufferPool()
func writeErrorResponse(w io.StringWriter, err error) {
message := "internal server error"
message := "mochi internal error"
var clientErr bittorrent.ClientError
if errors.As(err, &clientErr) {
message = clientErr.Error()
@@ -99,25 +100,33 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) {
bb.Write(peer.ID.Bytes())
}
bb.WriteString("4:porti")
bb.Write(fasthttp.AppendUint(nil, int(peer.Port())))
bb.Write([]byte{'e', 'e'})
port := peer.Port()
bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'})
}
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
bb := respBufferPool.Get()
defer respBufferPool.Put(bb)
bb.WriteString("d5:filesd")
for _, scrape := range resp.Files {
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
bb.WriteByte(':')
bb.Write([]byte(scrape.InfoHash))
bb.WriteString("d8:completei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
bb.WriteString("e10:downloadedi")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
bb.WriteString("e10:incompletei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
bb.WriteString("ee")
l := len(resp.Data)
if l > 0 {
if l > 1 {
sort.Slice(resp.Data, func(i, j int) bool {
return resp.Data[i].InfoHash < resp.Data[j].InfoHash
})
}
for _, scrape := range resp.Data {
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
bb.WriteByte(':')
bb.Write([]byte(scrape.InfoHash))
bb.WriteString("d8:completei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
bb.WriteString("e10:downloadedi")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
bb.WriteString("e10:incompletei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
bb.Write([]byte{'e', 'e'})
}
}
bb.Write([]byte{'e', 'e'})
_, _ = bb.WriteTo(w)

View File

@@ -34,7 +34,7 @@ const (
var (
logger = log.NewLogger("frontend/udp")
allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
)
func init() {
@@ -66,7 +66,7 @@ func (cfg Config) Validate() (validCfg Config) {
// Generate a private key if one isn't provided by the user.
if cfg.PrivateKey == "" {
pkeyRunes := make([]rune, defaultKeyLen)
pkeyRunes := make([]byte, defaultKeyLen)
for i := range pkeyRunes {
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// ErrInvalidQueryEscape is returned when a query string contains invalid
@@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
key, value = key[:i], key[i+1:]
}
var k, v string
k, err = url.QueryUnescape(string(key))
k, err = url.QueryUnescape(str2bytes.BytesToString(key))
if err != nil {
// QueryUnescape returns an error like "invalid escape: '%x'".
// But frontends record these errors to prometheus, which generates
@@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
// We log it here for debugging instead.
return nil, ErrInvalidQueryEscape
}
v, err = url.QueryUnescape(string(value))
v, err = url.QueryUnescape(str2bytes.BytesToString(value))
if err != nil {
// QueryUnescape returns an error like "invalid escape: '%x'".
// But frontends record these errors to prometheus, which generates
@@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) {
return q, nil
}
// String returns a string parsed from a query. Every key can be returned as a
// GetString returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings.
func (qp queryParams) String(key string) (string, bool) {
func (qp queryParams) GetString(key string) (string, bool) {
value, ok := qp.params[strings.ToLower(key)]
return value, ok
}

View File

@@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) {
t.Fatalf("expected values %v for %x", tt.values, tt.data)
} else {
for key, want := range tt.values {
if got, ok := params.String(key); !ok {
if got, ok := params.GetString(key); !ok {
t.Fatalf("params missing entry %s for data %x", key, tt.data)
} else if got != want {
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)

View File

@@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) {
buf := reqRespBufferPool.Get()
defer reqRespBufferPool.Put(buf)
writeHeader(buf, txID, errorActionID)
message := "mochi internal error"
var clientErr bittorrent.ClientError
// If the client wasn't at fault, acknowledge it.
if !errors.As(err, new(bittorrent.ClientError)) {
buf.WriteString("internal error occurred: ")
if errors.As(err, &clientErr) {
message = clientErr.Error()
} else {
logger.Error().Err(err).Msg("internal error")
}
buf.WriteString(err.Error())
buf.WriteString(message)
buf.WriteByte('\000')
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeAnnounceResponse encodes an announce response according to BEP 15.
@@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe
_ = binary.Write(buf, binary.BigEndian, peer.Port())
}
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeScrapeResponse encodes a scrape response according to BEP 15.
@@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon
writeHeader(buf, txID, scrapeActionID)
for _, scrape := range resp.Files {
for _, scrape := range resp.Data {
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
}
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeConnectionID encodes a new connection response according to BEP 15.
@@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) {
writeHeader(buf, txID, connectActionID)
buf.Write(connID)
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeHeader writes the action and transaction ID to the provided response

View File

@@ -6,7 +6,6 @@ package clientapproval
import (
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// ClientID represents the part of a PeerID that identifies a Peer's client
@@ -17,14 +16,9 @@ type ClientID [6]byte
func NewClientID(pid bittorrent.PeerID) ClientID {
var cid ClientID
if pid[0] == '-' {
copy(cid[:], pid[1:7])
cid = ClientID(pid[1:7])
} else {
copy(cid[:], pid[:6])
cid = ClientID(pid[:6])
}
return cid
}
func (cid ClientID) String() string {
return str2bytes.BytesToString(cid[:])
}

View File

@@ -20,23 +20,21 @@ func init() {
middleware.RegisterBuilder(Name, build)
}
var (
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
ErrClientUnapproved = bittorrent.ClientError("unapproved client")
errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
)
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi")
// Config represents all the values required by this middleware to validate
// peers based on their BitTorrent client ID.
type Config struct {
Whitelist []string
Blacklist []string
// Static list of client IDs.
ClientIDList []string `cfg:"client_id_list"`
// If Invert set to true, all client IDs stored in ClientIDList should be blacklisted.
Invert bool
}
type hook struct {
approved map[ClientID]struct{}
unapproved map[ClientID]struct{}
clientIDs map[ClientID]any
invert bool
}
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
@@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error
}
h := &hook{
approved: make(map[ClientID]struct{}),
unapproved: make(map[ClientID]struct{}),
clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)),
invert: cfg.Invert,
}
if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
return nil, errBothListsProvided
}
for _, cidString := range cfg.Whitelist {
for _, cidString := range cfg.ClientIDList {
cidBytes := []byte(cidString)
if len(cidBytes) != 6 {
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
}
var cid ClientID
copy(cid[:], cidBytes)
h.approved[cid] = struct{}{}
}
for _, cidString := range cfg.Blacklist {
cidBytes := []byte(cidString)
if len(cidBytes) != 6 {
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
}
var cid ClientID
copy(cid[:], cidBytes)
h.unapproved[cid] = struct{}{}
h.clientIDs[ClientID(cidBytes)] = true
}
return h, nil
}
// HandleAnnounce checks if specified ClientID is approved or not.
// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved,
// that means that ClientID is blacklisted.
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
clientID := NewClientID(req.ID)
if len(h.approved) > 0 {
if _, found := h.approved[clientID]; !found {
return ctx, ErrClientUnapproved
}
var err error
if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert {
err = ErrClientUnapproved
}
if len(h.unapproved) > 0 {
if _, found := h.unapproved[clientID]; found {
return ctx, ErrClientUnapproved
}
}
return ctx, nil
return ctx, err
}
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {

View File

@@ -19,7 +19,7 @@ var cases = []struct {
// Client ID is whitelisted
{
Config{
Whitelist: []string{"010203"},
ClientIDList: []string{"010203"},
},
"01020304050607080900",
true,
@@ -27,7 +27,7 @@ var cases = []struct {
// Client ID is not whitelisted
{
Config{
Whitelist: []string{"010203"},
ClientIDList: []string{"010203"},
},
"10203040506070809000",
false,
@@ -35,7 +35,8 @@ var cases = []struct {
// Client ID is not blacklisted
{
Config{
Blacklist: []string{"010203"},
ClientIDList: []string{"010203"},
Invert: true,
},
"00000000001234567890",
true,
@@ -43,7 +44,8 @@ var cases = []struct {
// Client ID is blacklisted
{
Config{
Blacklist: []string{"123456"},
ClientIDList: []string{"123456"},
Invert: true,
},
"12345678900000000000",
false,
@@ -53,7 +55,7 @@ var cases = []struct {
func TestHandleAnnounce(t *testing.T) {
for _, tt := range cases {
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist}
c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert}
h, err := build(c, nil)
require.Nil(t, err)
@@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) {
nctx, err := h.HandleAnnounce(ctx, req, resp)
require.Equal(t, ctx, nctx)
if tt.approved == true {
if tt.approved {
require.NotEqual(t, err, ErrClientUnapproved)
} else {
require.Equal(t, err, ErrClientUnapproved)

View File

@@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
if err != nil {
return
}
resp.Files = append(resp.Files, scr)
resp.Data = append(resp.Data, scr)
}
return ctx, nil

View File

@@ -34,10 +34,10 @@ func init() {
var (
logger = log.NewLogger("middleware/jwt")
// ErrMissingJWT is returned when a JWT is missing from a request.
ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt")
ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt")
// ErrInvalidJWT is returned when a JWT fails to verify.
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt")
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
@@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
if params != nil {
var found bool
if jwt, found = params.String(h.cfg.Header); found {
if jwt, found = params.GetString(h.cfg.Header); found {
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
jwt = jwt[len(bearerAuthPrefix):]
}

View File

@@ -55,7 +55,7 @@ type JWKSKeys struct {
type params map[string]string
func (p params) String(key string) (out string, found bool) {
func (p params) GetString(key string) (out string, found bool) {
out, found = p[key]
return
}

View File

@@ -2,7 +2,6 @@ package middleware
import (
"context"
"sort"
"time"
"github.com/sot-tech/mochi/bittorrent"
@@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
logger.Debug().Object("request", req).Msg("new scrape request")
resp = &bittorrent.ScrapeResponse{
Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
}
for _, h := range l.preHooks {
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
return nil, nil, err
}
}
sort.Sort(&resp.Files)
logger.Debug().Object("response", resp).Msg("generated scrape response")
return ctx, resp, nil

View File

@@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er
}
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent")
var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi")
type hook struct {
hashContainer container.Container