mirror of
https://github.com/sot-tech/mochi.git
synced 2026-05-26 17:54:45 -07:00
(partially tested) simplify client approval m/w
* sanitize code
This commit is contained in:
@@ -13,9 +13,9 @@ import (
|
|||||||
//
|
//
|
||||||
// See ParseURLData for specifics on parsing and limitations.
|
// See ParseURLData for specifics on parsing and limitations.
|
||||||
type Params interface {
|
type Params interface {
|
||||||
// String returns a string parsed from a query. Every key can be
|
// GetString returns a string parsed from a query. Every key can be
|
||||||
// returned as a string because they are encoded in the URL as strings.
|
// returned as a string because they are encoded in the URL as strings.
|
||||||
String(key string) (string, bool)
|
GetString(key string) (string, bool)
|
||||||
|
|
||||||
zerolog.LogObjectMarshaler
|
zerolog.LogObjectMarshaler
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) {
|
|||||||
// Scrapes wrapper of array of Scrape-s
|
// Scrapes wrapper of array of Scrape-s
|
||||||
type Scrapes []Scrape
|
type Scrapes []Scrape
|
||||||
|
|
||||||
func (s *Scrapes) Len() int {
|
|
||||||
return len(*s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Scrapes) Less(i, j int) bool {
|
|
||||||
return (*s)[i].InfoHash < (*s)[j].InfoHash
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Scrapes) Swap(i, j int) {
|
|
||||||
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalZerologArray writes array elements to zerolog event
|
// MarshalZerologArray writes array elements to zerolog event
|
||||||
func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
func (s Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||||
if s != nil {
|
for _, scrape := range s {
|
||||||
for _, scrape := range *s {
|
a.Object(scrape)
|
||||||
a.Object(scrape)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
|||||||
// The Scrapes must be in the same order as the InfoHashes in the corresponding
|
// The Scrapes must be in the same order as the InfoHashes in the corresponding
|
||||||
// ScrapeRequest.
|
// ScrapeRequest.
|
||||||
type ScrapeResponse struct {
|
type ScrapeResponse struct {
|
||||||
Files Scrapes
|
Data Scrapes
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalZerologObject writes fields into zerolog event
|
// MarshalZerologObject writes fields into zerolog event
|
||||||
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
|
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
|
||||||
e.Array("scrapes", &sr.Files)
|
e.Array("scrapes", sr.Data)
|
||||||
}
|
}
|
||||||
|
|||||||
6
dist/example_config.yaml
vendored
6
dist/example_config.yaml
vendored
@@ -190,10 +190,10 @@ prehooks:
|
|||||||
#
|
#
|
||||||
# - name: client approval
|
# - name: client approval
|
||||||
# config:
|
# config:
|
||||||
# whitelist:
|
# client_id_list:
|
||||||
# - "OP1011"
|
# - "OP1011"
|
||||||
# blacklist:
|
# true - whitelist mode, false - blacklist
|
||||||
# - "OP1012"
|
# invert: true
|
||||||
#
|
#
|
||||||
# - name: interval variation
|
# - name: interval variation
|
||||||
# config:
|
# config:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/sot-tech/mochi/bittorrent"
|
"github.com/sot-tech/mochi/bittorrent"
|
||||||
|
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,11 +13,11 @@ type queryParams struct {
|
|||||||
*fasthttp.Args
|
*fasthttp.Args
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string parsed from a query. Every key can be returned as a
|
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||||
// string because they are encoded in the URL as strings.
|
// string because they are encoded in the URL as strings.
|
||||||
func (qp queryParams) String(key string) (string, bool) {
|
func (qp queryParams) GetString(key string) (string, bool) {
|
||||||
v := qp.Peek(key)
|
v := qp.Peek(key)
|
||||||
return string(v), v != nil
|
return str2bytes.BytesToString(v), v != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InfoHashes returns a list of requested infohashes.
|
// InfoHashes returns a list of requested infohashes.
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"github.com/sot-tech/mochi/bittorrent"
|
"github.com/sot-tech/mochi/bittorrent"
|
||||||
"github.com/sot-tech/mochi/frontend"
|
"github.com/sot-tech/mochi/frontend"
|
||||||
|
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseOptions is the configuration used to parse an Announce Request.
|
// ParseOptions is the configuration used to parse an Announce Request.
|
||||||
@@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
|
|||||||
// Attempt to parse the event from the request.
|
// Attempt to parse the event from the request.
|
||||||
var eventStr string
|
var eventStr string
|
||||||
var err error
|
var err error
|
||||||
eventStr, request.EventProvided = qp.String("event")
|
eventStr, request.EventProvided = qp.GetString("event")
|
||||||
if request.EventProvided {
|
if request.EventProvided {
|
||||||
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
|
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -136,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR
|
|||||||
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
|
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
|
||||||
if opts.AllowIPSpoofing {
|
if opts.AllowIPSpoofing {
|
||||||
for _, f := range []string{"ip", "ipv4", "ipv6"} {
|
for _, f := range []string{"ip", "ipv4", "ipv6"} {
|
||||||
if ipStr, ok := p.String(f); ok {
|
if ipStr, ok := p.GetString(f); ok {
|
||||||
addresses.Add(parseRequestAddress(ipStr, true))
|
addresses.Add(parseRequestAddress(ipStr, true))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -146,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad
|
|||||||
for _, ipStr := range ipValues {
|
for _, ipStr := range ipValues {
|
||||||
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
|
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
|
||||||
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
|
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
|
||||||
addresses.Add(parseRequestAddress(string(ipStr), false))
|
addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ import (
|
|||||||
var respBufferPool = bytepool.NewBufferPool()
|
var respBufferPool = bytepool.NewBufferPool()
|
||||||
|
|
||||||
func writeErrorResponse(w io.StringWriter, err error) {
|
func writeErrorResponse(w io.StringWriter, err error) {
|
||||||
message := "internal server error"
|
message := "mochi internal error"
|
||||||
var clientErr bittorrent.ClientError
|
var clientErr bittorrent.ClientError
|
||||||
if errors.As(err, &clientErr) {
|
if errors.As(err, &clientErr) {
|
||||||
message = clientErr.Error()
|
message = clientErr.Error()
|
||||||
@@ -99,25 +100,33 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) {
|
|||||||
bb.Write(peer.ID.Bytes())
|
bb.Write(peer.ID.Bytes())
|
||||||
}
|
}
|
||||||
bb.WriteString("4:porti")
|
bb.WriteString("4:porti")
|
||||||
bb.Write(fasthttp.AppendUint(nil, int(peer.Port())))
|
port := peer.Port()
|
||||||
bb.Write([]byte{'e', 'e'})
|
bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'})
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
|
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
|
||||||
bb := respBufferPool.Get()
|
bb := respBufferPool.Get()
|
||||||
defer respBufferPool.Put(bb)
|
defer respBufferPool.Put(bb)
|
||||||
bb.WriteString("d5:filesd")
|
bb.WriteString("d5:filesd")
|
||||||
for _, scrape := range resp.Files {
|
l := len(resp.Data)
|
||||||
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
if l > 0 {
|
||||||
bb.WriteByte(':')
|
if l > 1 {
|
||||||
bb.Write([]byte(scrape.InfoHash))
|
sort.Slice(resp.Data, func(i, j int) bool {
|
||||||
bb.WriteString("d8:completei")
|
return resp.Data[i].InfoHash < resp.Data[j].InfoHash
|
||||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
})
|
||||||
bb.WriteString("e10:downloadedi")
|
}
|
||||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
for _, scrape := range resp.Data {
|
||||||
bb.WriteString("e10:incompletei")
|
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
||||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
bb.WriteByte(':')
|
||||||
bb.WriteString("ee")
|
bb.Write([]byte(scrape.InfoHash))
|
||||||
|
bb.WriteString("d8:completei")
|
||||||
|
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
||||||
|
bb.WriteString("e10:downloadedi")
|
||||||
|
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
||||||
|
bb.WriteString("e10:incompletei")
|
||||||
|
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
||||||
|
bb.Write([]byte{'e', 'e'})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
bb.Write([]byte{'e', 'e'})
|
bb.Write([]byte{'e', 'e'})
|
||||||
_, _ = bb.WriteTo(w)
|
_, _ = bb.WriteTo(w)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
logger = log.NewLogger("frontend/udp")
|
logger = log.NewLogger("frontend/udp")
|
||||||
allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -66,7 +66,7 @@ func (cfg Config) Validate() (validCfg Config) {
|
|||||||
|
|
||||||
// Generate a private key if one isn't provided by the user.
|
// Generate a private key if one isn't provided by the user.
|
||||||
if cfg.PrivateKey == "" {
|
if cfg.PrivateKey == "" {
|
||||||
pkeyRunes := make([]rune, defaultKeyLen)
|
pkeyRunes := make([]byte, defaultKeyLen)
|
||||||
for i := range pkeyRunes {
|
for i := range pkeyRunes {
|
||||||
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
|
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/sot-tech/mochi/bittorrent"
|
"github.com/sot-tech/mochi/bittorrent"
|
||||||
|
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrInvalidQueryEscape is returned when a query string contains invalid
|
// ErrInvalidQueryEscape is returned when a query string contains invalid
|
||||||
@@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
|||||||
key, value = key[:i], key[i+1:]
|
key, value = key[:i], key[i+1:]
|
||||||
}
|
}
|
||||||
var k, v string
|
var k, v string
|
||||||
k, err = url.QueryUnescape(string(key))
|
k, err = url.QueryUnescape(str2bytes.BytesToString(key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||||
// But frontends record these errors to prometheus, which generates
|
// But frontends record these errors to prometheus, which generates
|
||||||
@@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
|||||||
// We log it here for debugging instead.
|
// We log it here for debugging instead.
|
||||||
return nil, ErrInvalidQueryEscape
|
return nil, ErrInvalidQueryEscape
|
||||||
}
|
}
|
||||||
v, err = url.QueryUnescape(string(value))
|
v, err = url.QueryUnescape(str2bytes.BytesToString(value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||||
// But frontends record these errors to prometheus, which generates
|
// But frontends record these errors to prometheus, which generates
|
||||||
@@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
|||||||
return q, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string parsed from a query. Every key can be returned as a
|
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||||
// string because they are encoded in the URL as strings.
|
// string because they are encoded in the URL as strings.
|
||||||
func (qp queryParams) String(key string) (string, bool) {
|
func (qp queryParams) GetString(key string) (string, bool) {
|
||||||
value, ok := qp.params[strings.ToLower(key)]
|
value, ok := qp.params[strings.ToLower(key)]
|
||||||
return value, ok
|
return value, ok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) {
|
|||||||
t.Fatalf("expected values %v for %x", tt.values, tt.data)
|
t.Fatalf("expected values %v for %x", tt.values, tt.data)
|
||||||
} else {
|
} else {
|
||||||
for key, want := range tt.values {
|
for key, want := range tt.values {
|
||||||
if got, ok := params.String(key); !ok {
|
if got, ok := params.GetString(key); !ok {
|
||||||
t.Fatalf("params missing entry %s for data %x", key, tt.data)
|
t.Fatalf("params missing entry %s for data %x", key, tt.data)
|
||||||
} else if got != want {
|
} else if got != want {
|
||||||
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)
|
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)
|
||||||
|
|||||||
@@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) {
|
|||||||
buf := reqRespBufferPool.Get()
|
buf := reqRespBufferPool.Get()
|
||||||
defer reqRespBufferPool.Put(buf)
|
defer reqRespBufferPool.Put(buf)
|
||||||
writeHeader(buf, txID, errorActionID)
|
writeHeader(buf, txID, errorActionID)
|
||||||
|
message := "mochi internal error"
|
||||||
|
var clientErr bittorrent.ClientError
|
||||||
// If the client wasn't at fault, acknowledge it.
|
// If the client wasn't at fault, acknowledge it.
|
||||||
if !errors.As(err, new(bittorrent.ClientError)) {
|
if errors.As(err, &clientErr) {
|
||||||
buf.WriteString("internal error occurred: ")
|
message = clientErr.Error()
|
||||||
|
} else {
|
||||||
|
logger.Error().Err(err).Msg("internal error")
|
||||||
}
|
}
|
||||||
buf.WriteString(err.Error())
|
buf.WriteString(message)
|
||||||
buf.WriteByte('\000')
|
buf.WriteByte('\000')
|
||||||
_, _ = w.Write(buf.Bytes())
|
_, _ = buf.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeAnnounceResponse encodes an announce response according to BEP 15.
|
// writeAnnounceResponse encodes an announce response according to BEP 15.
|
||||||
@@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe
|
|||||||
_ = binary.Write(buf, binary.BigEndian, peer.Port())
|
_ = binary.Write(buf, binary.BigEndian, peer.Port())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = w.Write(buf.Bytes())
|
_, _ = buf.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeScrapeResponse encodes a scrape response according to BEP 15.
|
// writeScrapeResponse encodes a scrape response according to BEP 15.
|
||||||
@@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon
|
|||||||
|
|
||||||
writeHeader(buf, txID, scrapeActionID)
|
writeHeader(buf, txID, scrapeActionID)
|
||||||
|
|
||||||
for _, scrape := range resp.Files {
|
for _, scrape := range resp.Data {
|
||||||
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
|
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
|
||||||
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
|
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
|
||||||
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
|
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
|
||||||
}
|
}
|
||||||
|
_, _ = buf.WriteTo(w)
|
||||||
_, _ = w.Write(buf.Bytes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeConnectionID encodes a new connection response according to BEP 15.
|
// writeConnectionID encodes a new connection response according to BEP 15.
|
||||||
@@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) {
|
|||||||
|
|
||||||
writeHeader(buf, txID, connectActionID)
|
writeHeader(buf, txID, connectActionID)
|
||||||
buf.Write(connID)
|
buf.Write(connID)
|
||||||
_, _ = w.Write(buf.Bytes())
|
_, _ = buf.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHeader writes the action and transaction ID to the provided response
|
// writeHeader writes the action and transaction ID to the provided response
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package clientapproval
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sot-tech/mochi/bittorrent"
|
"github.com/sot-tech/mochi/bittorrent"
|
||||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientID represents the part of a PeerID that identifies a Peer's client
|
// ClientID represents the part of a PeerID that identifies a Peer's client
|
||||||
@@ -17,14 +16,9 @@ type ClientID [6]byte
|
|||||||
func NewClientID(pid bittorrent.PeerID) ClientID {
|
func NewClientID(pid bittorrent.PeerID) ClientID {
|
||||||
var cid ClientID
|
var cid ClientID
|
||||||
if pid[0] == '-' {
|
if pid[0] == '-' {
|
||||||
copy(cid[:], pid[1:7])
|
cid = ClientID(pid[1:7])
|
||||||
} else {
|
} else {
|
||||||
copy(cid[:], pid[:6])
|
cid = ClientID(pid[:6])
|
||||||
}
|
}
|
||||||
|
|
||||||
return cid
|
return cid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cid ClientID) String() string {
|
|
||||||
return str2bytes.BytesToString(cid[:])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -20,23 +20,21 @@ func init() {
|
|||||||
middleware.RegisterBuilder(Name, build)
|
middleware.RegisterBuilder(Name, build)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
||||||
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi")
|
||||||
ErrClientUnapproved = bittorrent.ClientError("unapproved client")
|
|
||||||
|
|
||||||
errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config represents all the values required by this middleware to validate
|
// Config represents all the values required by this middleware to validate
|
||||||
// peers based on their BitTorrent client ID.
|
// peers based on their BitTorrent client ID.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Whitelist []string
|
// Static list of client IDs.
|
||||||
Blacklist []string
|
ClientIDList []string `cfg:"client_id_list"`
|
||||||
|
// If Invert set to true, all client IDs stored in ClientIDList should be blacklisted.
|
||||||
|
Invert bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type hook struct {
|
type hook struct {
|
||||||
approved map[ClientID]struct{}
|
clientIDs map[ClientID]any
|
||||||
unapproved map[ClientID]struct{}
|
invert bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
|
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
|
||||||
@@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
h := &hook{
|
h := &hook{
|
||||||
approved: make(map[ClientID]struct{}),
|
clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)),
|
||||||
unapproved: make(map[ClientID]struct{}),
|
invert: cfg.Invert,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
|
for _, cidString := range cfg.ClientIDList {
|
||||||
return nil, errBothListsProvided
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cidString := range cfg.Whitelist {
|
|
||||||
cidBytes := []byte(cidString)
|
cidBytes := []byte(cidString)
|
||||||
if len(cidBytes) != 6 {
|
if len(cidBytes) != 6 {
|
||||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
||||||
}
|
}
|
||||||
var cid ClientID
|
h.clientIDs[ClientID(cidBytes)] = true
|
||||||
copy(cid[:], cidBytes)
|
|
||||||
h.approved[cid] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cidString := range cfg.Blacklist {
|
|
||||||
cidBytes := []byte(cidString)
|
|
||||||
if len(cidBytes) != 6 {
|
|
||||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
|
||||||
}
|
|
||||||
var cid ClientID
|
|
||||||
copy(cid[:], cidBytes)
|
|
||||||
h.unapproved[cid] = struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleAnnounce checks if specified ClientID is approved or not.
|
||||||
|
// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved,
|
||||||
|
// that means that ClientID is blacklisted.
|
||||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||||
clientID := NewClientID(req.ID)
|
var err error
|
||||||
|
if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert {
|
||||||
if len(h.approved) > 0 {
|
err = ErrClientUnapproved
|
||||||
if _, found := h.approved[clientID]; !found {
|
|
||||||
return ctx, ErrClientUnapproved
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(h.unapproved) > 0 {
|
return ctx, err
|
||||||
if _, found := h.unapproved[clientID]; found {
|
|
||||||
return ctx, ErrClientUnapproved
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ var cases = []struct {
|
|||||||
// Client ID is whitelisted
|
// Client ID is whitelisted
|
||||||
{
|
{
|
||||||
Config{
|
Config{
|
||||||
Whitelist: []string{"010203"},
|
ClientIDList: []string{"010203"},
|
||||||
},
|
},
|
||||||
"01020304050607080900",
|
"01020304050607080900",
|
||||||
true,
|
true,
|
||||||
@@ -27,7 +27,7 @@ var cases = []struct {
|
|||||||
// Client ID is not whitelisted
|
// Client ID is not whitelisted
|
||||||
{
|
{
|
||||||
Config{
|
Config{
|
||||||
Whitelist: []string{"010203"},
|
ClientIDList: []string{"010203"},
|
||||||
},
|
},
|
||||||
"10203040506070809000",
|
"10203040506070809000",
|
||||||
false,
|
false,
|
||||||
@@ -35,7 +35,8 @@ var cases = []struct {
|
|||||||
// Client ID is not blacklisted
|
// Client ID is not blacklisted
|
||||||
{
|
{
|
||||||
Config{
|
Config{
|
||||||
Blacklist: []string{"010203"},
|
ClientIDList: []string{"010203"},
|
||||||
|
Invert: true,
|
||||||
},
|
},
|
||||||
"00000000001234567890",
|
"00000000001234567890",
|
||||||
true,
|
true,
|
||||||
@@ -43,7 +44,8 @@ var cases = []struct {
|
|||||||
// Client ID is blacklisted
|
// Client ID is blacklisted
|
||||||
{
|
{
|
||||||
Config{
|
Config{
|
||||||
Blacklist: []string{"123456"},
|
ClientIDList: []string{"123456"},
|
||||||
|
Invert: true,
|
||||||
},
|
},
|
||||||
"12345678900000000000",
|
"12345678900000000000",
|
||||||
false,
|
false,
|
||||||
@@ -53,7 +55,7 @@ var cases = []struct {
|
|||||||
func TestHandleAnnounce(t *testing.T) {
|
func TestHandleAnnounce(t *testing.T) {
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
|
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
|
||||||
c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist}
|
c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert}
|
||||||
h, err := build(c, nil)
|
h, err := build(c, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
@@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) {
|
|||||||
|
|
||||||
nctx, err := h.HandleAnnounce(ctx, req, resp)
|
nctx, err := h.HandleAnnounce(ctx, req, resp)
|
||||||
require.Equal(t, ctx, nctx)
|
require.Equal(t, ctx, nctx)
|
||||||
if tt.approved == true {
|
if tt.approved {
|
||||||
require.NotEqual(t, err, ErrClientUnapproved)
|
require.NotEqual(t, err, ErrClientUnapproved)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, err, ErrClientUnapproved)
|
require.Equal(t, err, ErrClientUnapproved)
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.Files = append(resp.Files, scr)
|
resp.Data = append(resp.Data, scr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
|
|||||||
@@ -34,10 +34,10 @@ func init() {
|
|||||||
var (
|
var (
|
||||||
logger = log.NewLogger("middleware/jwt")
|
logger = log.NewLogger("middleware/jwt")
|
||||||
// ErrMissingJWT is returned when a JWT is missing from a request.
|
// ErrMissingJWT is returned when a JWT is missing from a request.
|
||||||
ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt")
|
ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt")
|
||||||
|
|
||||||
// ErrInvalidJWT is returned when a JWT fails to verify.
|
// ErrInvalidJWT is returned when a JWT fails to verify.
|
||||||
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
|
ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt")
|
||||||
|
|
||||||
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
|
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
|
||||||
|
|
||||||
@@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
|||||||
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
||||||
if params != nil {
|
if params != nil {
|
||||||
var found bool
|
var found bool
|
||||||
if jwt, found = params.String(h.cfg.Header); found {
|
if jwt, found = params.GetString(h.cfg.Header); found {
|
||||||
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
|
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
|
||||||
jwt = jwt[len(bearerAuthPrefix):]
|
jwt = jwt[len(bearerAuthPrefix):]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ type JWKSKeys struct {
|
|||||||
|
|
||||||
type params map[string]string
|
type params map[string]string
|
||||||
|
|
||||||
func (p params) String(key string) (out string, found bool) {
|
func (p params) GetString(key string) (out string, found bool) {
|
||||||
out, found = p[key]
|
out, found = p[key]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sort"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sot-tech/mochi/bittorrent"
|
"github.com/sot-tech/mochi/bittorrent"
|
||||||
@@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
|||||||
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
|
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
|
||||||
logger.Debug().Object("request", req).Msg("new scrape request")
|
logger.Debug().Object("request", req).Msg("new scrape request")
|
||||||
resp = &bittorrent.ScrapeResponse{
|
resp = &bittorrent.ScrapeResponse{
|
||||||
Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
||||||
}
|
}
|
||||||
for _, h := range l.preHooks {
|
for _, h := range l.preHooks {
|
||||||
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
|
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort.Sort(&resp.Files)
|
|
||||||
|
|
||||||
logger.Debug().Object("response", resp).Msg("generated scrape response")
|
logger.Debug().Object("response", resp).Msg("generated scrape response")
|
||||||
return ctx, resp, nil
|
return ctx, resp, nil
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
|
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
|
||||||
var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent")
|
var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi")
|
||||||
|
|
||||||
type hook struct {
|
type hook struct {
|
||||||
hashContainer container.Container
|
hashContainer container.Container
|
||||||
|
|||||||
Reference in New Issue
Block a user