diff --git a/bittorrent/params.go b/bittorrent/params.go index f97ac7a..e91977e 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -13,9 +13,9 @@ import ( // // See ParseURLData for specifics on parsing and limitations. type Params interface { - // String returns a string parsed from a query. Every key can be + // GetString returns a string parsed from a query. Every key can be // returned as a string because they are encoded in the URL as strings. - String(key string) (string, bool) + GetString(key string) (string, bool) zerolog.LogObjectMarshaler } diff --git a/bittorrent/request.go b/bittorrent/request.go index 150f87d..ad2a7ff 100644 --- a/bittorrent/request.go +++ b/bittorrent/request.go @@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) { // Scrapes wrapper of array of Scrape-s type Scrapes []Scrape -func (s *Scrapes) Len() int { - return len(*s) -} - -func (s *Scrapes) Less(i, j int) bool { - return (*s)[i].InfoHash < (*s)[j].InfoHash -} - -func (s *Scrapes) Swap(i, j int) { - (*s)[i], (*s)[j] = (*s)[j], (*s)[i] -} - // MarshalZerologArray writes array elements to zerolog event -func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) { - if s != nil { - for _, scrape := range *s { - a.Object(scrape) - } +func (s Scrapes) MarshalZerologArray(a *zerolog.Array) { + for _, scrape := range s { + a.Object(scrape) } } @@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) { // The Scrapes must be in the same order as the InfoHashes in the corresponding // ScrapeRequest. type ScrapeResponse struct { - Files Scrapes + Data Scrapes } // MarshalZerologObject writes fields into zerolog event func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) { - e.Array("scrapes", &sr.Files) + e.Array("scrapes", sr.Data) } diff --git a/dist/example_config.yaml b/dist/example_config.yaml index e5d7afb..34a9093 100644 --- a/dist/example_config.yaml +++ b/dist/example_config.yaml @@ -190,10 +190,10 @@ prehooks: # # - name: client approval # config: -# whitelist: +# client_id_list: # - "OP1011" -# blacklist: -# - "OP1012" +# true - whitelist mode, false - blacklist +# invert: true # # - name: interval variation # config: diff --git a/frontend/http/params.go b/frontend/http/params.go index 3e9f759..60886e6 100644 --- a/frontend/http/params.go +++ b/frontend/http/params.go @@ -3,6 +3,7 @@ package http import ( "github.com/rs/zerolog" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/str2bytes" "github.com/valyala/fasthttp" ) @@ -12,11 +13,11 @@ type queryParams struct { *fasthttp.Args } -// String returns a string parsed from a query. Every key can be returned as a +// GetString returns a string parsed from a query. Every key can be returned as a // string because they are encoded in the URL as strings. -func (qp queryParams) String(key string) (string, bool) { +func (qp queryParams) GetString(key string) (string, bool) { v := qp.Peek(key) - return string(v), v != nil + return str2bytes.BytesToString(v), v != nil } // InfoHashes returns a list of requested infohashes. diff --git a/frontend/http/parser.go b/frontend/http/parser.go index c96dd53..103d67b 100644 --- a/frontend/http/parser.go +++ b/frontend/http/parser.go @@ -5,10 +5,11 @@ import ( "errors" "net/netip" + "github.com/valyala/fasthttp" + "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/frontend" - - "github.com/valyala/fasthttp" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // ParseOptions is the configuration used to parse an Announce Request. @@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou // Attempt to parse the event from the request. var eventStr string var err error - eventStr, request.EventProvided = qp.String("event") + eventStr, request.EventProvided = qp.GetString("event") if request.EventProvided { if request.Event, err = bittorrent.NewEvent(eventStr); err != nil { return nil, err @@ -136,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) { if opts.AllowIPSpoofing { for _, f := range []string{"ip", "ipv4", "ipv6"} { - if ipStr, ok := p.String(f); ok { + if ipStr, ok := p.GetString(f); ok { addresses.Add(parseRequestAddress(ipStr, true)) } } @@ -146,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad for _, ipStr := range ipValues { for _, ipStr := range bytes.Split(ipStr, []byte{','}) { if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 { - addresses.Add(parseRequestAddress(string(ipStr), false)) + addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false)) } } } diff --git a/frontend/http/writer.go b/frontend/http/writer.go index 28c5610..7645b51 100644 --- a/frontend/http/writer.go +++ b/frontend/http/writer.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "sort" "strconv" "time" @@ -17,7 +18,7 @@ import ( var respBufferPool = bytepool.NewBufferPool() func writeErrorResponse(w io.StringWriter, err error) { - message := "internal server error" + message := "mochi internal error" var clientErr bittorrent.ClientError if errors.As(err, &clientErr) { message = clientErr.Error() @@ -99,25 +100,33 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) { bb.Write(peer.ID.Bytes()) } bb.WriteString("4:porti") - bb.Write(fasthttp.AppendUint(nil, int(peer.Port()))) - bb.Write([]byte{'e', 'e'}) + port := peer.Port() + bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'}) } func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) { bb := respBufferPool.Get() defer respBufferPool.Put(bb) bb.WriteString("d5:filesd") - for _, scrape := range resp.Files { - bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash))) - bb.WriteByte(':') - bb.Write([]byte(scrape.InfoHash)) - bb.WriteString("d8:completei") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete))) - bb.WriteString("e10:downloadedi") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches))) - bb.WriteString("e10:incompletei") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete))) - bb.WriteString("ee") + l := len(resp.Data) + if l > 0 { + if l > 1 { + sort.Slice(resp.Data, func(i, j int) bool { + return resp.Data[i].InfoHash < resp.Data[j].InfoHash + }) + } + for _, scrape := range resp.Data { + bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash))) + bb.WriteByte(':') + bb.Write([]byte(scrape.InfoHash)) + bb.WriteString("d8:completei") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete))) + bb.WriteString("e10:downloadedi") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches))) + bb.WriteString("e10:incompletei") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete))) + bb.Write([]byte{'e', 'e'}) + } } bb.Write([]byte{'e', 'e'}) _, _ = bb.WriteTo(w) diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index cb229b8..4e9a350 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -34,7 +34,7 @@ const ( var ( logger = log.NewLogger("frontend/udp") - allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") + allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" ) func init() { @@ -66,7 +66,7 @@ func (cfg Config) Validate() (validCfg Config) { // Generate a private key if one isn't provided by the user. if cfg.PrivateKey == "" { - pkeyRunes := make([]rune, defaultKeyLen) + pkeyRunes := make([]byte, defaultKeyLen) for i := range pkeyRunes { pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))] } diff --git a/frontend/udp/params.go b/frontend/udp/params.go index 59a10e9..3f6d78b 100644 --- a/frontend/udp/params.go +++ b/frontend/udp/params.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // ErrInvalidQueryEscape is returned when a query string contains invalid @@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) { key, value = key[:i], key[i+1:] } var k, v string - k, err = url.QueryUnescape(string(key)) + k, err = url.QueryUnescape(str2bytes.BytesToString(key)) if err != nil { // QueryUnescape returns an error like "invalid escape: '%x'". // But frontends record these errors to prometheus, which generates @@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) { // We log it here for debugging instead. return nil, ErrInvalidQueryEscape } - v, err = url.QueryUnescape(string(value)) + v, err = url.QueryUnescape(str2bytes.BytesToString(value)) if err != nil { // QueryUnescape returns an error like "invalid escape: '%x'". // But frontends record these errors to prometheus, which generates @@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) { return q, nil } -// String returns a string parsed from a query. Every key can be returned as a +// GetString returns a string parsed from a query. Every key can be returned as a // string because they are encoded in the URL as strings. -func (qp queryParams) String(key string) (string, bool) { +func (qp queryParams) GetString(key string) (string, bool) { value, ok := qp.params[strings.ToLower(key)] return value, ok } diff --git a/frontend/udp/parser_test.go b/frontend/udp/parser_test.go index 585327f..517b44a 100644 --- a/frontend/udp/parser_test.go +++ b/frontend/udp/parser_test.go @@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) { t.Fatalf("expected values %v for %x", tt.values, tt.data) } else { for key, want := range tt.values { - if got, ok := params.String(key); !ok { + if got, ok := params.GetString(key); !ok { t.Fatalf("params missing entry %s for data %x", key, tt.data) } else if got != want { t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data) diff --git a/frontend/udp/writer.go b/frontend/udp/writer.go index 61c195b..57cb00d 100644 --- a/frontend/udp/writer.go +++ b/frontend/udp/writer.go @@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) { buf := reqRespBufferPool.Get() defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, errorActionID) + message := "mochi internal error" + var clientErr bittorrent.ClientError // If the client wasn't at fault, acknowledge it. - if !errors.As(err, new(bittorrent.ClientError)) { - buf.WriteString("internal error occurred: ") + if errors.As(err, &clientErr) { + message = clientErr.Error() + } else { + logger.Error().Err(err).Msg("internal error") } - buf.WriteString(err.Error()) + buf.WriteString(message) buf.WriteByte('\000') - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeAnnounceResponse encodes an announce response according to BEP 15. @@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe _ = binary.Write(buf, binary.BigEndian, peer.Port()) } - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeScrapeResponse encodes a scrape response according to BEP 15. @@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon writeHeader(buf, txID, scrapeActionID) - for _, scrape := range resp.Files { + for _, scrape := range resp.Data { _ = binary.Write(buf, binary.BigEndian, scrape.Complete) _ = binary.Write(buf, binary.BigEndian, scrape.Snatches) _ = binary.Write(buf, binary.BigEndian, scrape.Incomplete) } - - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeConnectionID encodes a new connection response according to BEP 15. @@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) { writeHeader(buf, txID, connectActionID) buf.Write(connID) - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeHeader writes the action and transaction ID to the provided response diff --git a/middleware/clientapproval/client_id.go b/middleware/clientapproval/client_id.go index 685bf67..6be8605 100644 --- a/middleware/clientapproval/client_id.go +++ b/middleware/clientapproval/client_id.go @@ -6,7 +6,6 @@ package clientapproval import ( "github.com/sot-tech/mochi/bittorrent" - "github.com/sot-tech/mochi/pkg/str2bytes" ) // ClientID represents the part of a PeerID that identifies a Peer's client @@ -17,14 +16,9 @@ type ClientID [6]byte func NewClientID(pid bittorrent.PeerID) ClientID { var cid ClientID if pid[0] == '-' { - copy(cid[:], pid[1:7]) + cid = ClientID(pid[1:7]) } else { - copy(cid[:], pid[:6]) + cid = ClientID(pid[:6]) } - return cid } - -func (cid ClientID) String() string { - return str2bytes.BytesToString(cid[:]) -} diff --git a/middleware/clientapproval/clientapproval.go b/middleware/clientapproval/clientapproval.go index 8c272dd..ef9e6a9 100644 --- a/middleware/clientapproval/clientapproval.go +++ b/middleware/clientapproval/clientapproval.go @@ -20,23 +20,21 @@ func init() { middleware.RegisterBuilder(Name, build) } -var ( - // ErrClientUnapproved is the error returned when a client's PeerID is invalid. - ErrClientUnapproved = bittorrent.ClientError("unapproved client") - - errBothListsProvided = errors.New("using both whitelist and blacklist is invalid") -) +// ErrClientUnapproved is the error returned when a client's PeerID is invalid. +var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi") // Config represents all the values required by this middleware to validate // peers based on their BitTorrent client ID. type Config struct { - Whitelist []string - Blacklist []string + // Static list of client IDs. + ClientIDList []string `cfg:"client_id_list"` + // If Invert set to true, all client IDs stored in ClientIDList should be blacklisted. + Invert bool } type hook struct { - approved map[ClientID]struct{} - unapproved map[ClientID]struct{} + clientIDs map[ClientID]any + invert bool } func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) { @@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error } h := &hook{ - approved: make(map[ClientID]struct{}), - unapproved: make(map[ClientID]struct{}), + clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)), + invert: cfg.Invert, } - if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 { - return nil, errBothListsProvided - } - - for _, cidString := range cfg.Whitelist { + for _, cidString := range cfg.ClientIDList { cidBytes := []byte(cidString) if len(cidBytes) != 6 { return nil, errors.New("client ID " + cidString + " must be 6 bytes") } - var cid ClientID - copy(cid[:], cidBytes) - h.approved[cid] = struct{}{} - } - - for _, cidString := range cfg.Blacklist { - cidBytes := []byte(cidString) - if len(cidBytes) != 6 { - return nil, errors.New("client ID " + cidString + " must be 6 bytes") - } - var cid ClientID - copy(cid[:], cidBytes) - h.unapproved[cid] = struct{}{} + h.clientIDs[ClientID(cidBytes)] = true } return h, nil } +// HandleAnnounce checks if specified ClientID is approved or not. +// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved, +// that means that ClientID is blacklisted. func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { - clientID := NewClientID(req.ID) - - if len(h.approved) > 0 { - if _, found := h.approved[clientID]; !found { - return ctx, ErrClientUnapproved - } + var err error + if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert { + err = ErrClientUnapproved } - if len(h.unapproved) > 0 { - if _, found := h.unapproved[clientID]; found { - return ctx, ErrClientUnapproved - } - } - - return ctx, nil + return ctx, err } func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { diff --git a/middleware/clientapproval/clientapproval_test.go b/middleware/clientapproval/clientapproval_test.go index 8dabd2a..7c85ea2 100644 --- a/middleware/clientapproval/clientapproval_test.go +++ b/middleware/clientapproval/clientapproval_test.go @@ -19,7 +19,7 @@ var cases = []struct { // Client ID is whitelisted { Config{ - Whitelist: []string{"010203"}, + ClientIDList: []string{"010203"}, }, "01020304050607080900", true, @@ -27,7 +27,7 @@ var cases = []struct { // Client ID is not whitelisted { Config{ - Whitelist: []string{"010203"}, + ClientIDList: []string{"010203"}, }, "10203040506070809000", false, @@ -35,7 +35,8 @@ var cases = []struct { // Client ID is not blacklisted { Config{ - Blacklist: []string{"010203"}, + ClientIDList: []string{"010203"}, + Invert: true, }, "00000000001234567890", true, @@ -43,7 +44,8 @@ var cases = []struct { // Client ID is blacklisted { Config{ - Blacklist: []string{"123456"}, + ClientIDList: []string{"123456"}, + Invert: true, }, "12345678900000000000", false, @@ -53,7 +55,7 @@ var cases = []struct { func TestHandleAnnounce(t *testing.T) { for _, tt := range cases { t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) { - c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist} + c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert} h, err := build(c, nil) require.Nil(t, err) @@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) { nctx, err := h.HandleAnnounce(ctx, req, resp) require.Equal(t, ctx, nctx) - if tt.approved == true { + if tt.approved { require.NotEqual(t, err, ErrClientUnapproved) } else { require.Equal(t, err, ErrClientUnapproved) diff --git a/middleware/hooks.go b/middleware/hooks.go index 11c6861..ce7ec89 100644 --- a/middleware/hooks.go +++ b/middleware/hooks.go @@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR if err != nil { return } - resp.Files = append(resp.Files, scr) + resp.Data = append(resp.Data, scr) } return ctx, nil diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 61142e0..aa23c3b 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -34,10 +34,10 @@ func init() { var ( logger = log.NewLogger("middleware/jwt") // ErrMissingJWT is returned when a JWT is missing from a request. - ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt") + ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt") // ErrInvalidJWT is returned when a JWT fails to verify. - ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt") + ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt") errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL") @@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, func (h *hook) getJWTString(params bittorrent.Params) (jwt string) { if params != nil { var found bool - if jwt, found = params.String(h.cfg.Header); found { + if jwt, found = params.GetString(h.cfg.Header); found { if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) { jwt = jwt[len(bearerAuthPrefix):] } diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index 35ce4ee..ade3931 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -55,7 +55,7 @@ type JWKSKeys struct { type params map[string]string -func (p params) String(key string) (out string, found bool) { +func (p params) GetString(key string) (out string, found bool) { out, found = p[key] return } diff --git a/middleware/logic.go b/middleware/logic.go index 2476ff3..27d500c 100644 --- a/middleware/logic.go +++ b/middleware/logic.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "sort" "time" "github.com/sot-tech/mochi/bittorrent" @@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) { logger.Debug().Object("request", req).Msg("new scrape request") resp = &bittorrent.ScrapeResponse{ - Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)), + Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)), } for _, h := range l.preHooks { if ctx, err = h.HandleScrape(ctx, req, resp); err != nil { return nil, nil, err } } - sort.Sort(&resp.Files) logger.Debug().Object("response", resp).Msg("generated scrape response") return ctx, resp, nil diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index 977f8b1..c873967 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er } // ErrTorrentUnapproved is the error returned when a torrent hash is invalid. -var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent") +var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi") type hook struct { hashContainer container.Container