mirror of
https://github.com/sot-tech/mochi.git
synced 2026-05-27 02:04:45 -07:00
(partially tested) simplify client approval m/w
* sanitize code
This commit is contained in:
@@ -13,9 +13,9 @@ import (
|
||||
//
|
||||
// See ParseURLData for specifics on parsing and limitations.
|
||||
type Params interface {
|
||||
// String returns a string parsed from a query. Every key can be
|
||||
// GetString returns a string parsed from a query. Every key can be
|
||||
// returned as a string because they are encoded in the URL as strings.
|
||||
String(key string) (string, bool)
|
||||
GetString(key string) (string, bool)
|
||||
|
||||
zerolog.LogObjectMarshaler
|
||||
}
|
||||
|
||||
@@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) {
|
||||
// Scrapes wrapper of array of Scrape-s
|
||||
type Scrapes []Scrape
|
||||
|
||||
func (s *Scrapes) Len() int {
|
||||
return len(*s)
|
||||
}
|
||||
|
||||
func (s *Scrapes) Less(i, j int) bool {
|
||||
return (*s)[i].InfoHash < (*s)[j].InfoHash
|
||||
}
|
||||
|
||||
func (s *Scrapes) Swap(i, j int) {
|
||||
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
|
||||
}
|
||||
|
||||
// MarshalZerologArray writes array elements to zerolog event
|
||||
func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
if s != nil {
|
||||
for _, scrape := range *s {
|
||||
a.Object(scrape)
|
||||
}
|
||||
func (s Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
for _, scrape := range s {
|
||||
a.Object(scrape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
// The Scrapes must be in the same order as the InfoHashes in the corresponding
|
||||
// ScrapeRequest.
|
||||
type ScrapeResponse struct {
|
||||
Files Scrapes
|
||||
Data Scrapes
|
||||
}
|
||||
|
||||
// MarshalZerologObject writes fields into zerolog event
|
||||
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Array("scrapes", &sr.Files)
|
||||
e.Array("scrapes", sr.Data)
|
||||
}
|
||||
|
||||
6
dist/example_config.yaml
vendored
6
dist/example_config.yaml
vendored
@@ -190,10 +190,10 @@ prehooks:
|
||||
#
|
||||
# - name: client approval
|
||||
# config:
|
||||
# whitelist:
|
||||
# client_id_list:
|
||||
# - "OP1011"
|
||||
# blacklist:
|
||||
# - "OP1012"
|
||||
# true - whitelist mode, false - blacklist
|
||||
# invert: true
|
||||
#
|
||||
# - name: interval variation
|
||||
# config:
|
||||
|
||||
@@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
@@ -12,11 +13,11 @@ type queryParams struct {
|
||||
*fasthttp.Args
|
||||
}
|
||||
|
||||
// String returns a string parsed from a query. Every key can be returned as a
|
||||
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||
// string because they are encoded in the URL as strings.
|
||||
func (qp queryParams) String(key string) (string, bool) {
|
||||
func (qp queryParams) GetString(key string) (string, bool) {
|
||||
v := qp.Peek(key)
|
||||
return string(v), v != nil
|
||||
return str2bytes.BytesToString(v), v != nil
|
||||
}
|
||||
|
||||
// InfoHashes returns a list of requested infohashes.
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/frontend"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// ParseOptions is the configuration used to parse an Announce Request.
|
||||
@@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
|
||||
// Attempt to parse the event from the request.
|
||||
var eventStr string
|
||||
var err error
|
||||
eventStr, request.EventProvided = qp.String("event")
|
||||
eventStr, request.EventProvided = qp.GetString("event")
|
||||
if request.EventProvided {
|
||||
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
|
||||
return nil, err
|
||||
@@ -136,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR
|
||||
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
|
||||
if opts.AllowIPSpoofing {
|
||||
for _, f := range []string{"ip", "ipv4", "ipv6"} {
|
||||
if ipStr, ok := p.String(f); ok {
|
||||
if ipStr, ok := p.GetString(f); ok {
|
||||
addresses.Add(parseRequestAddress(ipStr, true))
|
||||
}
|
||||
}
|
||||
@@ -146,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad
|
||||
for _, ipStr := range ipValues {
|
||||
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
|
||||
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
|
||||
addresses.Add(parseRequestAddress(string(ipStr), false))
|
||||
addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
var respBufferPool = bytepool.NewBufferPool()
|
||||
|
||||
func writeErrorResponse(w io.StringWriter, err error) {
|
||||
message := "internal server error"
|
||||
message := "mochi internal error"
|
||||
var clientErr bittorrent.ClientError
|
||||
if errors.As(err, &clientErr) {
|
||||
message = clientErr.Error()
|
||||
@@ -99,25 +100,33 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) {
|
||||
bb.Write(peer.ID.Bytes())
|
||||
}
|
||||
bb.WriteString("4:porti")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(peer.Port())))
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
port := peer.Port()
|
||||
bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'})
|
||||
}
|
||||
|
||||
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
|
||||
bb := respBufferPool.Get()
|
||||
defer respBufferPool.Put(bb)
|
||||
bb.WriteString("d5:filesd")
|
||||
for _, scrape := range resp.Files {
|
||||
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
||||
bb.WriteByte(':')
|
||||
bb.Write([]byte(scrape.InfoHash))
|
||||
bb.WriteString("d8:completei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
||||
bb.WriteString("e10:downloadedi")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
||||
bb.WriteString("e10:incompletei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
||||
bb.WriteString("ee")
|
||||
l := len(resp.Data)
|
||||
if l > 0 {
|
||||
if l > 1 {
|
||||
sort.Slice(resp.Data, func(i, j int) bool {
|
||||
return resp.Data[i].InfoHash < resp.Data[j].InfoHash
|
||||
})
|
||||
}
|
||||
for _, scrape := range resp.Data {
|
||||
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
||||
bb.WriteByte(':')
|
||||
bb.Write([]byte(scrape.InfoHash))
|
||||
bb.WriteString("d8:completei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
||||
bb.WriteString("e10:downloadedi")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
||||
bb.WriteString("e10:incompletei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
}
|
||||
}
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
_, _ = bb.WriteTo(w)
|
||||
|
||||
@@ -34,7 +34,7 @@ const (
|
||||
|
||||
var (
|
||||
logger = log.NewLogger("frontend/udp")
|
||||
allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
||||
allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -66,7 +66,7 @@ func (cfg Config) Validate() (validCfg Config) {
|
||||
|
||||
// Generate a private key if one isn't provided by the user.
|
||||
if cfg.PrivateKey == "" {
|
||||
pkeyRunes := make([]rune, defaultKeyLen)
|
||||
pkeyRunes := make([]byte, defaultKeyLen)
|
||||
for i := range pkeyRunes {
|
||||
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// ErrInvalidQueryEscape is returned when a query string contains invalid
|
||||
@@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
key, value = key[:i], key[i+1:]
|
||||
}
|
||||
var k, v string
|
||||
k, err = url.QueryUnescape(string(key))
|
||||
k, err = url.QueryUnescape(str2bytes.BytesToString(key))
|
||||
if err != nil {
|
||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||
// But frontends record these errors to prometheus, which generates
|
||||
@@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
// We log it here for debugging instead.
|
||||
return nil, ErrInvalidQueryEscape
|
||||
}
|
||||
v, err = url.QueryUnescape(string(value))
|
||||
v, err = url.QueryUnescape(str2bytes.BytesToString(value))
|
||||
if err != nil {
|
||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||
// But frontends record these errors to prometheus, which generates
|
||||
@@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// String returns a string parsed from a query. Every key can be returned as a
|
||||
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||
// string because they are encoded in the URL as strings.
|
||||
func (qp queryParams) String(key string) (string, bool) {
|
||||
func (qp queryParams) GetString(key string) (string, bool) {
|
||||
value, ok := qp.params[strings.ToLower(key)]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) {
|
||||
t.Fatalf("expected values %v for %x", tt.values, tt.data)
|
||||
} else {
|
||||
for key, want := range tt.values {
|
||||
if got, ok := params.String(key); !ok {
|
||||
if got, ok := params.GetString(key); !ok {
|
||||
t.Fatalf("params missing entry %s for data %x", key, tt.data)
|
||||
} else if got != want {
|
||||
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)
|
||||
|
||||
@@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) {
|
||||
buf := reqRespBufferPool.Get()
|
||||
defer reqRespBufferPool.Put(buf)
|
||||
writeHeader(buf, txID, errorActionID)
|
||||
message := "mochi internal error"
|
||||
var clientErr bittorrent.ClientError
|
||||
// If the client wasn't at fault, acknowledge it.
|
||||
if !errors.As(err, new(bittorrent.ClientError)) {
|
||||
buf.WriteString("internal error occurred: ")
|
||||
if errors.As(err, &clientErr) {
|
||||
message = clientErr.Error()
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("internal error")
|
||||
}
|
||||
buf.WriteString(err.Error())
|
||||
buf.WriteString(message)
|
||||
buf.WriteByte('\000')
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeAnnounceResponse encodes an announce response according to BEP 15.
|
||||
@@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe
|
||||
_ = binary.Write(buf, binary.BigEndian, peer.Port())
|
||||
}
|
||||
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeScrapeResponse encodes a scrape response according to BEP 15.
|
||||
@@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon
|
||||
|
||||
writeHeader(buf, txID, scrapeActionID)
|
||||
|
||||
for _, scrape := range resp.Files {
|
||||
for _, scrape := range resp.Data {
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
|
||||
}
|
||||
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeConnectionID encodes a new connection response according to BEP 15.
|
||||
@@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) {
|
||||
|
||||
writeHeader(buf, txID, connectActionID)
|
||||
buf.Write(connID)
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeHeader writes the action and transaction ID to the provided response
|
||||
|
||||
@@ -6,7 +6,6 @@ package clientapproval
|
||||
|
||||
import (
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// ClientID represents the part of a PeerID that identifies a Peer's client
|
||||
@@ -17,14 +16,9 @@ type ClientID [6]byte
|
||||
func NewClientID(pid bittorrent.PeerID) ClientID {
|
||||
var cid ClientID
|
||||
if pid[0] == '-' {
|
||||
copy(cid[:], pid[1:7])
|
||||
cid = ClientID(pid[1:7])
|
||||
} else {
|
||||
copy(cid[:], pid[:6])
|
||||
cid = ClientID(pid[:6])
|
||||
}
|
||||
|
||||
return cid
|
||||
}
|
||||
|
||||
func (cid ClientID) String() string {
|
||||
return str2bytes.BytesToString(cid[:])
|
||||
}
|
||||
|
||||
@@ -20,23 +20,21 @@ func init() {
|
||||
middleware.RegisterBuilder(Name, build)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
||||
ErrClientUnapproved = bittorrent.ClientError("unapproved client")
|
||||
|
||||
errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
|
||||
)
|
||||
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
||||
var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi")
|
||||
|
||||
// Config represents all the values required by this middleware to validate
|
||||
// peers based on their BitTorrent client ID.
|
||||
type Config struct {
|
||||
Whitelist []string
|
||||
Blacklist []string
|
||||
// Static list of client IDs.
|
||||
ClientIDList []string `cfg:"client_id_list"`
|
||||
// If Invert set to true, all client IDs stored in ClientIDList should be blacklisted.
|
||||
Invert bool
|
||||
}
|
||||
|
||||
type hook struct {
|
||||
approved map[ClientID]struct{}
|
||||
unapproved map[ClientID]struct{}
|
||||
clientIDs map[ClientID]any
|
||||
invert bool
|
||||
}
|
||||
|
||||
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
|
||||
@@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error
|
||||
}
|
||||
|
||||
h := &hook{
|
||||
approved: make(map[ClientID]struct{}),
|
||||
unapproved: make(map[ClientID]struct{}),
|
||||
clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)),
|
||||
invert: cfg.Invert,
|
||||
}
|
||||
|
||||
if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
|
||||
return nil, errBothListsProvided
|
||||
}
|
||||
|
||||
for _, cidString := range cfg.Whitelist {
|
||||
for _, cidString := range cfg.ClientIDList {
|
||||
cidBytes := []byte(cidString)
|
||||
if len(cidBytes) != 6 {
|
||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
||||
}
|
||||
var cid ClientID
|
||||
copy(cid[:], cidBytes)
|
||||
h.approved[cid] = struct{}{}
|
||||
}
|
||||
|
||||
for _, cidString := range cfg.Blacklist {
|
||||
cidBytes := []byte(cidString)
|
||||
if len(cidBytes) != 6 {
|
||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
||||
}
|
||||
var cid ClientID
|
||||
copy(cid[:], cidBytes)
|
||||
h.unapproved[cid] = struct{}{}
|
||||
h.clientIDs[ClientID(cidBytes)] = true
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// HandleAnnounce checks if specified ClientID is approved or not.
|
||||
// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved,
|
||||
// that means that ClientID is blacklisted.
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
clientID := NewClientID(req.ID)
|
||||
|
||||
if len(h.approved) > 0 {
|
||||
if _, found := h.approved[clientID]; !found {
|
||||
return ctx, ErrClientUnapproved
|
||||
}
|
||||
var err error
|
||||
if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert {
|
||||
err = ErrClientUnapproved
|
||||
}
|
||||
|
||||
if len(h.unapproved) > 0 {
|
||||
if _, found := h.unapproved[clientID]; found {
|
||||
return ctx, ErrClientUnapproved
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||
|
||||
@@ -19,7 +19,7 @@ var cases = []struct {
|
||||
// Client ID is whitelisted
|
||||
{
|
||||
Config{
|
||||
Whitelist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
},
|
||||
"01020304050607080900",
|
||||
true,
|
||||
@@ -27,7 +27,7 @@ var cases = []struct {
|
||||
// Client ID is not whitelisted
|
||||
{
|
||||
Config{
|
||||
Whitelist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
},
|
||||
"10203040506070809000",
|
||||
false,
|
||||
@@ -35,7 +35,8 @@ var cases = []struct {
|
||||
// Client ID is not blacklisted
|
||||
{
|
||||
Config{
|
||||
Blacklist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
Invert: true,
|
||||
},
|
||||
"00000000001234567890",
|
||||
true,
|
||||
@@ -43,7 +44,8 @@ var cases = []struct {
|
||||
// Client ID is blacklisted
|
||||
{
|
||||
Config{
|
||||
Blacklist: []string{"123456"},
|
||||
ClientIDList: []string{"123456"},
|
||||
Invert: true,
|
||||
},
|
||||
"12345678900000000000",
|
||||
false,
|
||||
@@ -53,7 +55,7 @@ var cases = []struct {
|
||||
func TestHandleAnnounce(t *testing.T) {
|
||||
for _, tt := range cases {
|
||||
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
|
||||
c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist}
|
||||
c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert}
|
||||
h, err := build(c, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
@@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) {
|
||||
|
||||
nctx, err := h.HandleAnnounce(ctx, req, resp)
|
||||
require.Equal(t, ctx, nctx)
|
||||
if tt.approved == true {
|
||||
if tt.approved {
|
||||
require.NotEqual(t, err, ErrClientUnapproved)
|
||||
} else {
|
||||
require.Equal(t, err, ErrClientUnapproved)
|
||||
|
||||
@@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Files = append(resp.Files, scr)
|
||||
resp.Data = append(resp.Data, scr)
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
|
||||
@@ -34,10 +34,10 @@ func init() {
|
||||
var (
|
||||
logger = log.NewLogger("middleware/jwt")
|
||||
// ErrMissingJWT is returned when a JWT is missing from a request.
|
||||
ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt")
|
||||
ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt")
|
||||
|
||||
// ErrInvalidJWT is returned when a JWT fails to verify.
|
||||
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
|
||||
ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt")
|
||||
|
||||
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
|
||||
|
||||
@@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
||||
if params != nil {
|
||||
var found bool
|
||||
if jwt, found = params.String(h.cfg.Header); found {
|
||||
if jwt, found = params.GetString(h.cfg.Header); found {
|
||||
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
|
||||
jwt = jwt[len(bearerAuthPrefix):]
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type JWKSKeys struct {
|
||||
|
||||
type params map[string]string
|
||||
|
||||
func (p params) String(key string) (out string, found bool) {
|
||||
func (p params) GetString(key string) (out string, found bool) {
|
||||
out, found = p[key]
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
@@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
|
||||
logger.Debug().Object("request", req).Msg("new scrape request")
|
||||
resp = &bittorrent.ScrapeResponse{
|
||||
Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
||||
Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
||||
}
|
||||
for _, h := range l.preHooks {
|
||||
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
sort.Sort(&resp.Files)
|
||||
|
||||
logger.Debug().Object("response", resp).Msg("generated scrape response")
|
||||
return ctx, resp, nil
|
||||
|
||||
@@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er
|
||||
}
|
||||
|
||||
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
|
||||
var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent")
|
||||
var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi")
|
||||
|
||||
type hook struct {
|
||||
hashContainer container.Container
|
||||
|
||||
Reference in New Issue
Block a user