diff --git a/bittorrent/params.go b/bittorrent/params.go index f97ac7a..e91977e 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -13,9 +13,9 @@ import ( // // See ParseURLData for specifics on parsing and limitations. type Params interface { - // String returns a string parsed from a query. Every key can be + // GetString returns a string parsed from a query. Every key can be // returned as a string because they are encoded in the URL as strings. - String(key string) (string, bool) + GetString(key string) (string, bool) zerolog.LogObjectMarshaler } diff --git a/bittorrent/peer.go b/bittorrent/peer.go index 3227e42..8d0b323 100644 --- a/bittorrent/peer.go +++ b/bittorrent/peer.go @@ -6,14 +6,12 @@ package bittorrent import ( "crypto/sha1" "crypto/sha256" - "encoding/binary" "encoding/hex" - "errors" "fmt" - "net" "net/netip" "github.com/rs/zerolog" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // PeerIDLen is length of peer id field in bytes @@ -25,26 +23,31 @@ type PeerID [PeerIDLen]byte // ErrInvalidPeerIDSize holds error about invalid PeerID size var ErrInvalidPeerIDSize = fmt.Errorf("peer ID must be %d bytes", PeerIDLen) +var zeroPeerID PeerID + // NewPeerID creates a PeerID from a byte slice. // // It panics if b is not 20 bytes long. func NewPeerID(b []byte) (PeerID, error) { - var p PeerID if len(b) != PeerIDLen { - return p, ErrInvalidPeerIDSize + return zeroPeerID, ErrInvalidPeerIDSize } - copy(p[:], b) - return p, nil + return PeerID(b), nil +} + +// Bytes returns slice of bytes represents this PeerID +func (p PeerID) Bytes() []byte { + return p[:] } // String implements fmt.Stringer, returning the base16 encoded PeerID. func (p PeerID) String() string { - return hex.EncodeToString(p[:]) + return hex.EncodeToString(p.Bytes()) } // RawString returns a 20-byte string of the raw bytes of the ID. func (p PeerID) RawString() string { - return string(p[:]) + return str2bytes.BytesToString(p.Bytes()) } // InfoHash represents an infohash. @@ -55,16 +58,33 @@ const ( InfoHashV1Len = sha1.Size // InfoHashV2Len ... sha256.Size InfoHashV2Len = sha256.Size - // NoneInfoHash dummy invalid InfoHash - NoneInfoHash InfoHash = "" ) -var ( - // ErrInvalidHashType holds error about invalid InfoHash input type - ErrInvalidHashType = errors.New("info hash must be provided as byte slice or raw/hex string") - // ErrInvalidHashSize holds error about invalid InfoHash size - ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes", InfoHashV1Len, InfoHashV2Len) -) +// ErrInvalidHashSize holds error about invalid InfoHash size +var ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes or same sizes x2 (if HEX encoded)", InfoHashV1Len, InfoHashV2Len) + +// NewInfoHash creates an InfoHash from raw/hex byte slice. +func NewInfoHash(data []byte) (InfoHash, error) { + var ih InfoHash + switch l := len(data); l { + case InfoHashV1Len, InfoHashV2Len: + ih = InfoHash(data) + case InfoHashV1Len * 2, InfoHashV2Len * 2: + bb := make([]byte, l/2) + if _, err := hex.Decode(bb, data); err != nil { + return "", err + } + ih = InfoHash(str2bytes.BytesToString(bb)) + default: + return "", ErrInvalidHashSize + } + return ih, nil +} + +// NewInfoHashString creates an InfoHash from raw/hex string. +func NewInfoHashString(data string) (InfoHash, error) { + return NewInfoHash(str2bytes.StringToBytes(data)) +} // TruncateV1 returns truncated to 20-bytes length array of the corresponding InfoHash. // If InfoHash is V2 (32 bytes), it will be truncated to 20 bytes @@ -76,48 +96,14 @@ func (i InfoHash) TruncateV1() InfoHash { return i } -// NewInfoHash creates an InfoHash from a byte slice or raw/hex string. -func NewInfoHash(data any) (InfoHash, error) { - if data == nil { - return NoneInfoHash, ErrInvalidHashType - } - var ba []byte - switch t := data.(type) { - case [InfoHashV1Len]byte: - ba = t[:] - case [InfoHashV2Len]byte: - ba = t[:] - case []byte: - l := len(t) - if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 { - ba = make([]byte, l/2) - if _, err := hex.Decode(ba, t); err != nil { - return NoneInfoHash, err - } - } else { - ba = t - } - case string: - l := len(t) - if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 { - var err error - if ba, err = hex.DecodeString(t); err != nil { - return NoneInfoHash, err - } - } else { - ba = []byte(t) - } - } - l := len(ba) - if l != InfoHashV1Len && l != InfoHashV2Len { - return NoneInfoHash, ErrInvalidHashSize - } - return InfoHash(ba), nil +// Bytes returns slice of bytes represents this InfoHash +func (i InfoHash) Bytes() []byte { + return str2bytes.StringToBytes(string(i)) } // String implements fmt.Stringer, returning the base16 encoded InfoHash. func (i InfoHash) String() string { - return hex.EncodeToString([]byte(i)) + return hex.EncodeToString(i.Bytes()) } // RawString returns a string of the raw bytes of the InfoHash. @@ -132,47 +118,6 @@ type Peer struct { netip.AddrPort } -// PeerMinimumLen is the least allowed length of string serialized Peer -const PeerMinimumLen = PeerIDLen + 2 + net.IPv4len - -// ErrInvalidPeerDataSize holds error about invalid Peer data size -var ErrInvalidPeerDataSize = fmt.Errorf("invalid peer data it must be at least %d bytes (InfoHash + Port + IPv4)", PeerMinimumLen) - -// NewPeer constructs Peer from serialized by Peer.RawString data: PeerID[20by]Port[2by]net.IP[4/16by] -func NewPeer(data string) (Peer, error) { - var peer Peer - if len(data) < PeerMinimumLen { - return peer, ErrInvalidPeerDataSize - } - b := []byte(data) - peerID, err := NewPeerID(b[:PeerIDLen]) - if err == nil { - if addr, isOk := netip.AddrFromSlice(b[PeerIDLen+2:]); isOk { - peer = Peer{ - ID: peerID, - AddrPort: netip.AddrPortFrom( - addr.Unmap(), - binary.BigEndian.Uint16(b[PeerIDLen:PeerIDLen+2]), - ), - } - } else { - err = ErrInvalidIP - } - } - - return peer, err -} - -// RawString generates concatenation of PeerID, net port and IP-address -func (p Peer) RawString() string { - ip := p.Addr() - b := make([]byte, PeerIDLen+2+(ip.BitLen()/8)) - copy(b[:PeerIDLen], p.ID[:]) - binary.BigEndian.PutUint16(b[PeerIDLen:PeerIDLen+2], p.Port()) - copy(b[PeerIDLen+2:], ip.AsSlice()) - return string(b) -} - // Addr returns unmapped peer's IP address func (p Peer) Addr() netip.Addr { return p.AddrPort.Addr().Unmap() diff --git a/bittorrent/request.go b/bittorrent/request.go index 150f87d..ad2a7ff 100644 --- a/bittorrent/request.go +++ b/bittorrent/request.go @@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) { // Scrapes wrapper of array of Scrape-s type Scrapes []Scrape -func (s *Scrapes) Len() int { - return len(*s) -} - -func (s *Scrapes) Less(i, j int) bool { - return (*s)[i].InfoHash < (*s)[j].InfoHash -} - -func (s *Scrapes) Swap(i, j int) { - (*s)[i], (*s)[j] = (*s)[j], (*s)[i] -} - // MarshalZerologArray writes array elements to zerolog event -func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) { - if s != nil { - for _, scrape := range *s { - a.Object(scrape) - } +func (s Scrapes) MarshalZerologArray(a *zerolog.Array) { + for _, scrape := range s { + a.Object(scrape) } } @@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) { // The Scrapes must be in the same order as the InfoHashes in the corresponding // ScrapeRequest. type ScrapeResponse struct { - Files Scrapes + Data Scrapes } // MarshalZerologObject writes fields into zerolog event func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) { - e.Array("scrapes", &sr.Files) + e.Array("scrapes", sr.Data) } diff --git a/cmd/mochi/server_test.go b/cmd/mochi/server_test.go new file mode 100644 index 0000000..5af270b --- /dev/null +++ b/cmd/mochi/server_test.go @@ -0,0 +1,288 @@ +package main + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math" + "math/rand" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + cr "crypto/rand" + + "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/frontend" + hf "github.com/sot-tech/mochi/frontend/http" + l "github.com/sot-tech/mochi/pkg/log" + "github.com/sot-tech/mochi/pkg/str2bytes" +) + +const ( + connectUDPPacketSize = 16 + announceUDPPacketSize = 98 + announceNumWant = 10 + + timeout = time.Second * 2 +) + +var ( + udpConnectHeader = []byte{0x0, 0x0, 0x4, 0x17, 0x27, 0x10, 0x19, 0x80} + errUDPSendTruncated = errors.New("data not fully sent") + errUDPRecvTruncated = errors.New("data not fully received") + errUDPUnexpectedAction = errors.New("unexpected action") + errTxIDMissmatch = errors.New("transaction ID missmatch") + hashes = make([][]byte, 100) + peers = make([][]byte, 100) +) + +func init() { + _ = l.ConfigureLogger("", "error", false, false) + for i := range hashes { + bb := make([]byte, bittorrent.InfoHashV1Len) + if _, err := cr.Read(bb); err != nil { + panic(err) + } + hashes[i] = bb + } + + for i := range peers { + bb := make([]byte, bittorrent.PeerIDLen) + if _, err := cr.Read(bb); err != nil { + panic(err) + } + peers[i] = bb + } +} + +func buildUDPConnReq() []byte { + req := make([]byte, connectUDPPacketSize) + copy(req, udpConnectHeader) + + // TxID + binary.BigEndian.PutUint32(req[12:16], rand.Uint32()) + return req +} + +func sendUDPConnReq(addr string) ([]byte, []byte, error) { + req := buildUDPConnReq() + + conn, err := net.Dial("udp", addr) + if err != nil { + return nil, nil, err + } + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + + defer conn.Close() + + n, err := conn.Write(req) + if err != nil { + return nil, nil, err + } + + if n != len(req) { + return nil, nil, errUDPSendTruncated + } + + resp := make([]byte, 16) + n, err = conn.Read(resp) + if err != nil { + return nil, nil, err + } + + if n != len(resp) { + return nil, nil, errUDPRecvTruncated + } + + action := binary.BigEndian.Uint32(resp[:4]) + if action != 0 { + return nil, nil, errUDPUnexpectedAction + } + + if !bytes.Equal(resp[4:8], req[12:16]) { + return nil, nil, errTxIDMissmatch + } + + // TxID, ConnectionID + return resp[4:8], resp[8:16], nil +} + +func buildAnnounceUDPReq(txID, connID []byte) []byte { + req := make([]byte, announceUDPPacketSize) + + // Connection ID + copy(req[:8], connID) + + // Action + req[11] = 1 + + // TxID + copy(req[12:16], txID) + + // InfoHash + copy(req[16:36], hashes[rand.Intn(len(hashes))]) + + // PeerID + copy(req[36:56], peers[rand.Intn(len(peers))]) + + var down, left uint64 + if rand.Intn(2) == 0 { + down, left = 1, 0 + } else { + down, left = 0, 1 + } + // Downloaded + binary.BigEndian.PutUint64(req[56:64], down) + // Left + binary.BigEndian.PutUint64(req[64:72], left) + + // Event + req[83] = 1 + + // Numwant + req[92], req[95] = byte(announceNumWant>>24), byte(announceNumWant>>16) + + // Port + p := rand.Intn(math.MaxInt16) + 1 + req[96], req[97] = byte(p>>8), byte(p) + return req +} + +func BenchmarkServerUDPAnnounce(b *testing.B) { + var s Server + if err := s.Run(QuickConfig); err != nil { + b.Fatal(err) + } + defer s.Shutdown() + + addr := "127.0.0.1" + frontend.DefaultListenAddress + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + txID, connID, err := sendUDPConnReq(addr) + if err != nil { + b.Log(err) + return + } + req := buildAnnounceUDPReq(txID, connID) + + conn, err := net.Dial("udp", addr) + if err != nil { + b.Log(err) + return + } + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + n, err := conn.Write(req) + if err != nil { + _ = conn.Close() + b.Log(err) + return + } + if n != announceUDPPacketSize { + _ = conn.Close() + b.Logf("packet not fully sent, %d bytes instead of %d", n, announceUDPPacketSize) + return + } + resp := make([]byte, 1024) + n, err = conn.Read(resp) + _ = conn.Close() + if err != nil { + b.Log(err) + return + } + if n < 20 { + b.Logf("packet not fully received, only %d bytes", n) + return + } + action := binary.BigEndian.Uint32(resp[:4]) + if action != 1 { + if action == 3 { + errVal := string(resp[8:n]) + b.Logf("tracker error: %s", errVal) + } else { + b.Logf("unexpected action: %d", action) + } + return + } + + if !bytes.Equal(resp[4:8], req[12:16]) { + b.Log("transaction ID missmatch") + } + } + }) +} + +func sendHTTPReq(u string) (err error) { + var r *http.Response + // nolint:gosec + if r, err = http.Get(u); err == nil { + defer r.Body.Close() + if r.StatusCode < 400 { + var out []byte + if out, err = io.ReadAll(r.Body); err == nil { + sout := string(out) + if strings.Contains(sout, "failure reason") { + return errors.New(sout) + } + } + } else { + return errors.New(r.Status) + } + } + return +} + +func BenchmarkServerHTTPAnnounce(b *testing.B) { + var s Server + if err := s.Run(QuickConfig); err != nil { + b.Fatal(err) + } + defer s.Shutdown() + + reqs := make([]string, len(hashes)*len(peers)) + addr := "127.0.0.1" + frontend.DefaultListenAddress + for i := range reqs { + var down, left string + if rand.Intn(2) == 0 { + down, left = "1", "0" + } else { + down, left = "0", "1" + } + u := url.URL{ + Scheme: "http", + Host: addr, + Path: hf.DefaultAnnounceRoute, + RawQuery: url.Values{ + "event": []string{bittorrent.StartedStr}, + "compact": []string{"1"}, + "left": []string{left}, + "downloaded": []string{down}, + "uploaded": []string{"0"}, + "numwant": []string{"1"}, + "port": []string{strconv.FormatInt(int64(rand.Intn(math.MaxInt16)+1), 10)}, + "info_hash": []string{str2bytes.BytesToString(hashes[rand.Intn(len(hashes))])}, + "peer_id": []string{str2bytes.BytesToString(peers[rand.Intn(len(peers))])}, + }.Encode(), + } + reqs[i] = u.String() + } + + var cnt atomic.Uint32 + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := sendHTTPReq(reqs[int(cnt.Add(1))%len(reqs)]); err != nil { + b.Log(err) + } + } + }) +} diff --git a/dist/example_config.yaml b/dist/example_config.yaml index e5d7afb..34a9093 100644 --- a/dist/example_config.yaml +++ b/dist/example_config.yaml @@ -190,10 +190,10 @@ prehooks: # # - name: client approval # config: -# whitelist: +# client_id_list: # - "OP1011" -# blacklist: -# - "OP1012" +# true - whitelist mode, false - blacklist +# invert: true # # - name: interval variation # config: diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index 13df09d..58a113c 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -51,11 +51,15 @@ type Config struct { } const ( - defaultReadTimeout = 2 * time.Second - defaultWriteTimeout = 2 * time.Second - defaultIdleTimeout = 30 * time.Second - defaultAnnounceRoute = "/announce" - defaultScrapeRoute = "/scrape" + defaultReadTimeout = 2 * time.Second + defaultWriteTimeout = 2 * time.Second + defaultIdleTimeout = 30 * time.Second + // DefaultAnnounceRoute is the default url path to listen announce + // requests if nothing else provided + DefaultAnnounceRoute = "/announce" + // DefaultScrapeRoute is the default url path to listen scrape + // requests if nothing else provided + DefaultScrapeRoute = "/scrape" ) // Validate sanity checks values set in a config and returns a new config with @@ -98,7 +102,7 @@ func (cfg Config) Validate() (validCfg Config, err error) { } } if len(cfg.AnnounceRoutes) == 0 { - validCfg.AnnounceRoutes = []string{defaultAnnounceRoute} + validCfg.AnnounceRoutes = []string{DefaultAnnounceRoute} logger.Warn(). Str("name", "AnnounceRoutes"). Strs("provided", cfg.AnnounceRoutes). @@ -106,7 +110,7 @@ func (cfg Config) Validate() (validCfg Config, err error) { Msg("falling back to default configuration") } if len(cfg.ScrapeRoutes) == 0 { - validCfg.ScrapeRoutes = []string{defaultScrapeRoute} + validCfg.ScrapeRoutes = []string{DefaultScrapeRoute} logger.Warn(). Str("name", "ScrapeRoutes"). Strs("provided", cfg.ScrapeRoutes). @@ -211,7 +215,10 @@ func runServer(s *fasthttp.Server, cfg *Config) { err = s.ServeTLS(ln, "", "") } } - if !errors.Is(err, http.ErrServerClosed) { + defer ln.Close() + if err == nil { + logger.Info().Msg("server stopped") + } else if !errors.Is(err, http.ErrServerClosed) { logger.Fatal().Err(err).Msg("server failed") } } diff --git a/frontend/http/frontend_test.go b/frontend/http/frontend_test.go index 0bf4779..f47271b 100644 --- a/frontend/http/frontend_test.go +++ b/frontend/http/frontend_test.go @@ -17,8 +17,6 @@ import ( "github.com/sot-tech/mochi/pkg/log" ) -const iterations = 10000 - var ( addr = fmt.Sprintf("127.0.0.1:%d", rand.Int63n(10000)+16384) hashes = make([]string, 10) @@ -87,7 +85,7 @@ func BenchmarkPing(b *testing.B) { Path: "ping", } us := u.String() - for i := 0; i < iterations; i++ { + for i := 0; i < b.N; i++ { if err := runGet(us, false); err != nil { b.Error(err) } @@ -95,11 +93,11 @@ func BenchmarkPing(b *testing.B) { } func BenchmarkAnnounce(b *testing.B) { - for i := 0; i < iterations; i++ { + for i := 0; i < b.N; i++ { u := url.URL{ Scheme: "http", Host: addr, - Path: defaultAnnounceRoute, + Path: DefaultAnnounceRoute, RawQuery: url.Values{ "event": []string{bittorrent.StartedStr}, "compact": []string{"1"}, @@ -119,11 +117,11 @@ func BenchmarkAnnounce(b *testing.B) { } func BenchmarkScrape(b *testing.B) { - for i := 0; i < iterations; i++ { + for i := 0; i < b.N; i++ { u := url.URL{ Scheme: "http", Host: addr, - Path: defaultScrapeRoute, + Path: DefaultScrapeRoute, RawQuery: url.Values{"info_hash": hashes[:len(hashes)/2]}.Encode(), } if err := runGet(u.String(), true); err != nil { diff --git a/frontend/http/params.go b/frontend/http/params.go index 3e9f759..60886e6 100644 --- a/frontend/http/params.go +++ b/frontend/http/params.go @@ -3,6 +3,7 @@ package http import ( "github.com/rs/zerolog" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/str2bytes" "github.com/valyala/fasthttp" ) @@ -12,11 +13,11 @@ type queryParams struct { *fasthttp.Args } -// String returns a string parsed from a query. Every key can be returned as a +// GetString returns a string parsed from a query. Every key can be returned as a // string because they are encoded in the URL as strings. -func (qp queryParams) String(key string) (string, bool) { +func (qp queryParams) GetString(key string) (string, bool) { v := qp.Peek(key) - return string(v), v != nil + return str2bytes.BytesToString(v), v != nil } // InfoHashes returns a list of requested infohashes. diff --git a/frontend/http/parser.go b/frontend/http/parser.go index 6542496..103d67b 100644 --- a/frontend/http/parser.go +++ b/frontend/http/parser.go @@ -5,10 +5,11 @@ import ( "errors" "net/netip" + "github.com/valyala/fasthttp" + "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/frontend" - - "github.com/valyala/fasthttp" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // ParseOptions is the configuration used to parse an Announce Request. @@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou // Attempt to parse the event from the request. var eventStr string var err error - eventStr, request.EventProvided = qp.String("event") + eventStr, request.EventProvided = qp.GetString("event") if request.EventProvided { if request.Event, err = bittorrent.NewEvent(eventStr); err != nil { return nil, err @@ -57,7 +58,6 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou if len(infoHashes) > 1 { return nil, errMultipleInfoHashes } - // FIXME: make sure that we have a copy of InfoHash request.InfoHash = infoHashes[0] // Parse the PeerID from the request. @@ -137,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) { if opts.AllowIPSpoofing { for _, f := range []string{"ip", "ipv4", "ipv6"} { - if ipStr, ok := p.String(f); ok { + if ipStr, ok := p.GetString(f); ok { addresses.Add(parseRequestAddress(ipStr, true)) } } @@ -147,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad for _, ipStr := range ipValues { for _, ipStr := range bytes.Split(ipStr, []byte{','}) { if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 { - addresses.Add(parseRequestAddress(string(ipStr), false)) + addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false)) } } } diff --git a/frontend/http/writer.go b/frontend/http/writer.go index b384703..7645b51 100644 --- a/frontend/http/writer.go +++ b/frontend/http/writer.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "sort" "strconv" "time" @@ -17,7 +18,7 @@ import ( var respBufferPool = bytepool.NewBufferPool() func writeErrorResponse(w io.StringWriter, err error) { - message := "internal server error" + message := "mochi internal error" var clientErr bittorrent.ClientError if errors.As(err, &clientErr) { message = clientErr.Error() @@ -96,28 +97,36 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) { bb.WriteString(addr) if includePeerID { bb.WriteString("7:peer id20:") - bb.Write(peer.ID[:]) + bb.Write(peer.ID.Bytes()) } bb.WriteString("4:porti") - bb.Write(fasthttp.AppendUint(nil, int(peer.Port()))) - bb.Write([]byte{'e', 'e'}) + port := peer.Port() + bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'}) } func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) { bb := respBufferPool.Get() defer respBufferPool.Put(bb) bb.WriteString("d5:filesd") - for _, scrape := range resp.Files { - bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash))) - bb.WriteByte(':') - bb.Write([]byte(scrape.InfoHash)) - bb.WriteString("d8:completei") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete))) - bb.WriteString("e10:downloadedi") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches))) - bb.WriteString("e10:incompletei") - bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete))) - bb.WriteString("ee") + l := len(resp.Data) + if l > 0 { + if l > 1 { + sort.Slice(resp.Data, func(i, j int) bool { + return resp.Data[i].InfoHash < resp.Data[j].InfoHash + }) + } + for _, scrape := range resp.Data { + bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash))) + bb.WriteByte(':') + bb.Write([]byte(scrape.InfoHash)) + bb.WriteString("d8:completei") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete))) + bb.WriteString("e10:downloadedi") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches))) + bb.WriteString("e10:incompletei") + bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete))) + bb.Write([]byte{'e', 'e'}) + } } bb.Write([]byte{'e', 'e'}) _, _ = bb.WriteTo(w) diff --git a/frontend/options.go b/frontend/options.go index 61fc631..ba315c7 100644 --- a/frontend/options.go +++ b/frontend/options.go @@ -9,7 +9,9 @@ import ( "github.com/libp2p/go-reuseport" ) -const defaultListenAddress = ":6969" +// DefaultListenAddress is the default net listen address +// if nothing else provided +const DefaultListenAddress = ":6969" var errUnexpectedListenerType = errors.New("unexpected listener type") @@ -26,7 +28,7 @@ type ListenOptions struct { func (lo ListenOptions) Validate(logger *log.Logger) (validOptions ListenOptions) { validOptions = lo if len(lo.Addr) == 0 { - validOptions.Addr = defaultListenAddress + validOptions.Addr = DefaultListenAddress logger.Warn(). Str("name", "Addr"). Str("provided", lo.Addr). diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index de6f4fa..4e9a350 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -34,7 +34,7 @@ const ( var ( logger = log.NewLogger("frontend/udp") - allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") + allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" ) func init() { @@ -57,16 +57,16 @@ func (cfg Config) Validate() (validCfg Config) { validCfg.ListenOptions = cfg.ListenOptions.Validate(logger) if cfg.Workers == 0 { - cfg.Workers = 1 + validCfg.Workers = 1 } - if cfg.Workers > 1 && !cfg.ReusePort { - cfg.ReusePort = true + if validCfg.Workers > 1 && !validCfg.ReusePort { + validCfg.ReusePort = true logger.Warn().Msg("forcibly enabling ReusePort because Workers > 1") } // Generate a private key if one isn't provided by the user. if cfg.PrivateKey == "" { - pkeyRunes := make([]rune, defaultKeyLen) + pkeyRunes := make([]byte, defaultKeyLen) for i := range pkeyRunes { pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))] } diff --git a/frontend/udp/params.go b/frontend/udp/params.go index 59a10e9..3f6d78b 100644 --- a/frontend/udp/params.go +++ b/frontend/udp/params.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // ErrInvalidQueryEscape is returned when a query string contains invalid @@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) { key, value = key[:i], key[i+1:] } var k, v string - k, err = url.QueryUnescape(string(key)) + k, err = url.QueryUnescape(str2bytes.BytesToString(key)) if err != nil { // QueryUnescape returns an error like "invalid escape: '%x'". // But frontends record these errors to prometheus, which generates @@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) { // We log it here for debugging instead. return nil, ErrInvalidQueryEscape } - v, err = url.QueryUnescape(string(value)) + v, err = url.QueryUnescape(str2bytes.BytesToString(value)) if err != nil { // QueryUnescape returns an error like "invalid escape: '%x'". // But frontends record these errors to prometheus, which generates @@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) { return q, nil } -// String returns a string parsed from a query. Every key can be returned as a +// GetString returns a string parsed from a query. Every key can be returned as a // string because they are encoded in the URL as strings. -func (qp queryParams) String(key string) (string, bool) { +func (qp queryParams) GetString(key string) (string, bool) { value, ok := qp.params[strings.ToLower(key)] return value, ok } diff --git a/frontend/udp/parser.go b/frontend/udp/parser.go index 2117392..6f229bb 100644 --- a/frontend/udp/parser.go +++ b/frontend/udp/parser.go @@ -71,7 +71,6 @@ func parseAnnounce(r Request, v6Action bool, opts frontend.ParseOptions) (*bitto // XXX: pure V2 hashes will cause invalid parsing, // but BEP-52 says, that V2 hashes SHOULD be truncated - // FIXME: make sure that we have a copy of InfoHash request.InfoHash, err = bittorrent.NewInfoHash(r.Packet[16:36]) if err != nil { return nil, errInvalidInfoHash @@ -178,7 +177,6 @@ func parseScrape(r Request, opts frontend.ParseOptions) (*bittorrent.ScrapeReque var request *bittorrent.ScrapeRequest for len(r.Packet) >= bittorrent.InfoHashV1Len { var ih bittorrent.InfoHash - // FIXME: make sure that we have a copy of InfoHash if ih, err = bittorrent.NewInfoHash(r.Packet[:bittorrent.InfoHashV1Len]); err == nil { infoHashes = append(infoHashes, ih) r.Packet = r.Packet[bittorrent.InfoHashV1Len:] diff --git a/frontend/udp/parser_test.go b/frontend/udp/parser_test.go index 585327f..517b44a 100644 --- a/frontend/udp/parser_test.go +++ b/frontend/udp/parser_test.go @@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) { t.Fatalf("expected values %v for %x", tt.values, tt.data) } else { for key, want := range tt.values { - if got, ok := params.String(key); !ok { + if got, ok := params.GetString(key); !ok { t.Fatalf("params missing entry %s for data %x", key, tt.data) } else if got != want { t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data) diff --git a/frontend/udp/writer.go b/frontend/udp/writer.go index 61c195b..57cb00d 100644 --- a/frontend/udp/writer.go +++ b/frontend/udp/writer.go @@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) { buf := reqRespBufferPool.Get() defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, errorActionID) + message := "mochi internal error" + var clientErr bittorrent.ClientError // If the client wasn't at fault, acknowledge it. - if !errors.As(err, new(bittorrent.ClientError)) { - buf.WriteString("internal error occurred: ") + if errors.As(err, &clientErr) { + message = clientErr.Error() + } else { + logger.Error().Err(err).Msg("internal error") } - buf.WriteString(err.Error()) + buf.WriteString(message) buf.WriteByte('\000') - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeAnnounceResponse encodes an announce response according to BEP 15. @@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe _ = binary.Write(buf, binary.BigEndian, peer.Port()) } - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeScrapeResponse encodes a scrape response according to BEP 15. @@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon writeHeader(buf, txID, scrapeActionID) - for _, scrape := range resp.Files { + for _, scrape := range resp.Data { _ = binary.Write(buf, binary.BigEndian, scrape.Complete) _ = binary.Write(buf, binary.BigEndian, scrape.Snatches) _ = binary.Write(buf, binary.BigEndian, scrape.Incomplete) } - - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeConnectionID encodes a new connection response according to BEP 15. @@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) { writeHeader(buf, txID, connectActionID) buf.Write(connID) - _, _ = w.Write(buf.Bytes()) + _, _ = buf.WriteTo(w) } // writeHeader writes the action and transaction ID to the provided response diff --git a/go.mod b/go.mod index a3ba01b..775e4f5 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,11 @@ module github.com/sot-tech/mochi -go 1.19 +go 1.20 require ( code.cloudfoundry.org/go-diodes v0.0.0-20230317203753-49f1af6d2f1a github.com/MicahParks/keyfunc v1.9.0 - github.com/anacrolix/torrent v1.48.0 + github.com/anacrolix/torrent v1.49.0 github.com/cespare/xxhash/v2 v2.2.0 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/jackc/pgx/v5 v5.3.1 @@ -41,7 +41,7 @@ require ( github.com/klauspost/compress v1.16.3 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.17 // indirect + github.com/mattn/go-isatty v0.0.18 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/onsi/ginkgo/v2 v2.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index c47f4d5..e1195c1 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,8 @@ github.com/anacrolix/stm v0.2.0/go.mod h1:zoVQRvSiGjGoTmbM0vSLIiaKjWtNPeTvXUSdJQ github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= github.com/anacrolix/tagflag v1.0.0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= github.com/anacrolix/tagflag v1.1.0/go.mod h1:Scxs9CV10NQatSmbyjqmqmeQNwGzlNe0CMUMIxqHIG8= -github.com/anacrolix/torrent v1.48.0 h1:OQe1aQb8WnhDzpcI7r3yWoHzHWKyPbfhXGfO9Q/pvbY= -github.com/anacrolix/torrent v1.48.0/go.mod h1:3UtkJ8BnxXDRwvk+eT+uwiZalfFJ8YzAhvxe4QRPSJI= +github.com/anacrolix/torrent v1.49.0 h1:v/TAd8BKsZarYEYv7VkPNv8tY5zZCwQyxMMlKKbAF4I= +github.com/anacrolix/torrent v1.49.0/go.mod h1:qT3yS5oQwDUHnBXy+zf3nozLPudG7SFNDL3Jl/zQwFw= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= @@ -176,8 +176,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= +github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= diff --git a/middleware/clientapproval/client_id.go b/middleware/clientapproval/client_id.go index 601bc6a..6be8605 100644 --- a/middleware/clientapproval/client_id.go +++ b/middleware/clientapproval/client_id.go @@ -4,7 +4,9 @@ // - https://github.com/webtorrent/bittorrent-peerid/blob/master/lib/utils.js package clientapproval -import "github.com/sot-tech/mochi/bittorrent" +import ( + "github.com/sot-tech/mochi/bittorrent" +) // ClientID represents the part of a PeerID that identifies a Peer's client // software. @@ -14,14 +16,9 @@ type ClientID [6]byte func NewClientID(pid bittorrent.PeerID) ClientID { var cid ClientID if pid[0] == '-' { - copy(cid[:], pid[1:7]) + cid = ClientID(pid[1:7]) } else { - copy(cid[:], pid[:6]) + cid = ClientID(pid[:6]) } - return cid } - -func (cid ClientID) String() string { - return string(cid[:]) -} diff --git a/middleware/clientapproval/clientapproval.go b/middleware/clientapproval/clientapproval.go index 8c272dd..ef9e6a9 100644 --- a/middleware/clientapproval/clientapproval.go +++ b/middleware/clientapproval/clientapproval.go @@ -20,23 +20,21 @@ func init() { middleware.RegisterBuilder(Name, build) } -var ( - // ErrClientUnapproved is the error returned when a client's PeerID is invalid. - ErrClientUnapproved = bittorrent.ClientError("unapproved client") - - errBothListsProvided = errors.New("using both whitelist and blacklist is invalid") -) +// ErrClientUnapproved is the error returned when a client's PeerID is invalid. +var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi") // Config represents all the values required by this middleware to validate // peers based on their BitTorrent client ID. type Config struct { - Whitelist []string - Blacklist []string + // Static list of client IDs. + ClientIDList []string `cfg:"client_id_list"` + // If Invert set to true, all client IDs stored in ClientIDList should be blacklisted. + Invert bool } type hook struct { - approved map[ClientID]struct{} - unapproved map[ClientID]struct{} + clientIDs map[ClientID]any + invert bool } func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) { @@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error } h := &hook{ - approved: make(map[ClientID]struct{}), - unapproved: make(map[ClientID]struct{}), + clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)), + invert: cfg.Invert, } - if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 { - return nil, errBothListsProvided - } - - for _, cidString := range cfg.Whitelist { + for _, cidString := range cfg.ClientIDList { cidBytes := []byte(cidString) if len(cidBytes) != 6 { return nil, errors.New("client ID " + cidString + " must be 6 bytes") } - var cid ClientID - copy(cid[:], cidBytes) - h.approved[cid] = struct{}{} - } - - for _, cidString := range cfg.Blacklist { - cidBytes := []byte(cidString) - if len(cidBytes) != 6 { - return nil, errors.New("client ID " + cidString + " must be 6 bytes") - } - var cid ClientID - copy(cid[:], cidBytes) - h.unapproved[cid] = struct{}{} + h.clientIDs[ClientID(cidBytes)] = true } return h, nil } +// HandleAnnounce checks if specified ClientID is approved or not. +// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved, +// that means that ClientID is blacklisted. func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { - clientID := NewClientID(req.ID) - - if len(h.approved) > 0 { - if _, found := h.approved[clientID]; !found { - return ctx, ErrClientUnapproved - } + var err error + if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert { + err = ErrClientUnapproved } - if len(h.unapproved) > 0 { - if _, found := h.unapproved[clientID]; found { - return ctx, ErrClientUnapproved - } - } - - return ctx, nil + return ctx, err } func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) { diff --git a/middleware/clientapproval/clientapproval_test.go b/middleware/clientapproval/clientapproval_test.go index 8dabd2a..7c85ea2 100644 --- a/middleware/clientapproval/clientapproval_test.go +++ b/middleware/clientapproval/clientapproval_test.go @@ -19,7 +19,7 @@ var cases = []struct { // Client ID is whitelisted { Config{ - Whitelist: []string{"010203"}, + ClientIDList: []string{"010203"}, }, "01020304050607080900", true, @@ -27,7 +27,7 @@ var cases = []struct { // Client ID is not whitelisted { Config{ - Whitelist: []string{"010203"}, + ClientIDList: []string{"010203"}, }, "10203040506070809000", false, @@ -35,7 +35,8 @@ var cases = []struct { // Client ID is not blacklisted { Config{ - Blacklist: []string{"010203"}, + ClientIDList: []string{"010203"}, + Invert: true, }, "00000000001234567890", true, @@ -43,7 +44,8 @@ var cases = []struct { // Client ID is blacklisted { Config{ - Blacklist: []string{"123456"}, + ClientIDList: []string{"123456"}, + Invert: true, }, "12345678900000000000", false, @@ -53,7 +55,7 @@ var cases = []struct { func TestHandleAnnounce(t *testing.T) { for _, tt := range cases { t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) { - c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist} + c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert} h, err := build(c, nil) require.Nil(t, err) @@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) { nctx, err := h.HandleAnnounce(ctx, req, resp) require.Equal(t, ctx, nctx) - if tt.approved == true { + if tt.approved { require.NotEqual(t, err, ErrClientUnapproved) } else { require.Equal(t, err, ErrClientUnapproved) diff --git a/middleware/hooks.go b/middleware/hooks.go index 11c6861..ce7ec89 100644 --- a/middleware/hooks.go +++ b/middleware/hooks.go @@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR if err != nil { return } - resp.Files = append(resp.Files, scr) + resp.Data = append(resp.Data, scr) } return ctx, nil diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 65c6eaf..aa23c3b 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -34,10 +34,10 @@ func init() { var ( logger = log.NewLogger("middleware/jwt") // ErrMissingJWT is returned when a JWT is missing from a request. - ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt") + ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt") // ErrInvalidJWT is returned when a JWT fails to verify. - ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt") + ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt") errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL") @@ -160,7 +160,7 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque err = ErrInvalidJWT } else { var claimIH bittorrent.InfoHash - if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil { + if claimIH, err = bittorrent.NewInfoHashString(claims.InfoHash); err != nil { logger.Info(). Err(err). Object("source", req.RequestPeer). @@ -206,7 +206,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, } else { var claimIHs bittorrent.InfoHashes for _, s := range claims.InfoHashes { - if providedIh, err := bittorrent.NewInfoHash(s); err == nil { + if providedIh, err := bittorrent.NewInfoHashString(s); err == nil { claimIHs = append(claimIHs, providedIh) } else { logger.Info(). @@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, func (h *hook) getJWTString(params bittorrent.Params) (jwt string) { if params != nil { var found bool - if jwt, found = params.String(h.cfg.Header); found { + if jwt, found = params.GetString(h.cfg.Header); found { if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) { jwt = jwt[len(bearerAuthPrefix):] } diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index 35ce4ee..ade3931 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -55,7 +55,7 @@ type JWKSKeys struct { type params map[string]string -func (p params) String(key string) (out string, found bool) { +func (p params) GetString(key string) (out string, found bool) { out, found = p[key] return } diff --git a/middleware/logic.go b/middleware/logic.go index 2476ff3..27d500c 100644 --- a/middleware/logic.go +++ b/middleware/logic.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "sort" "time" "github.com/sot-tech/mochi/bittorrent" @@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) { logger.Debug().Object("request", req).Msg("new scrape request") resp = &bittorrent.ScrapeResponse{ - Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)), + Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)), } for _, h := range l.preHooks { if ctx, err = h.HandleScrape(ctx, req, resp); err != nil { return nil, nil, err } } - sort.Sort(&resp.Files) logger.Debug().Object("response", resp).Msg("generated scrape response") return ctx, resp, nil diff --git a/middleware/torrentapproval/container/list/list.go b/middleware/torrentapproval/container/list/list.go index cd61d25..f857b4a 100644 --- a/middleware/torrentapproval/container/list/list.go +++ b/middleware/torrentapproval/container/list/list.go @@ -56,7 +56,7 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er if len(c.HashList) > 0 { init := make([]storage.Entry, 0, len(c.HashList)) for _, hashString := range c.HashList { - ih, err := bittorrent.NewInfoHash(hashString) + ih, err := bittorrent.NewInfoHashString(hashString) if err != nil { return nil, fmt.Errorf("whitelist : %s : %w", hashString, err) } diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index 977f8b1..c873967 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er } // ErrTorrentUnapproved is the error returned when a torrent hash is invalid. -var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent") +var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi") type hook struct { hashContainer container.Container diff --git a/middleware/torrentapproval/torrentapproval_test.go b/middleware/torrentapproval/torrentapproval_test.go index 8bf69ba..a6ca2b9 100644 --- a/middleware/torrentapproval/torrentapproval_test.go +++ b/middleware/torrentapproval/torrentapproval_test.go @@ -84,7 +84,7 @@ func TestHandleAnnounce(t *testing.T) { req := &bittorrent.AnnounceRequest{} resp := &bittorrent.AnnounceResponse{} - hashinfo, err := bittorrent.NewInfoHash(tt.ih) + hashinfo, err := bittorrent.NewInfoHashString(tt.ih) require.Nil(t, err) req.InfoHash = hashinfo diff --git a/pkg/str2bytes/str2bytes.go b/pkg/str2bytes/str2bytes.go new file mode 100644 index 0000000..1309c81 --- /dev/null +++ b/pkg/str2bytes/str2bytes.go @@ -0,0 +1,21 @@ +// Package str2bytes provides fast, but unsafe functions to convert string to []byte +// or vice versa. +package str2bytes + +import "unsafe" + +// StringToBytes converts string to slice of bytes +// without allocation. Note, that returned slice +// must NOT be modified, since strings in Go are +// immutable. +// See unsafe.Slice. +func StringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// BytesToString converts slice of bytes to string +// without allocation. +// See unsafe.String +func BytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index bbc6365..7c3facf 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -114,19 +114,19 @@ func (s *store) delPeer(ctx context.Context, infoHashKey, peerID string) error { } func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) + return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer)) } func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) + return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer)) } func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) + return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer)) } func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) + return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer)) } func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) { @@ -134,18 +134,18 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee Stringer("infoHash", ih). Object("peer", peer). Msg("graduate leecher") - infoHash, peerID := ih.RawString(), peer.RawString() + infoHash, peerID := ih.RawString(), r.PackPeer(peer) ihSeederKey := r.InfoHashKey(infoHash, true, peer.Addr().Is6()) ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6()) var moved bool if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil { - if moved { - err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)) - } else { - err = s.addPeer(ctx, ihSeederKey, peerID) + if !moved { + err = s.SAdd(ctx, ihSeederKey, peerID).Err() } - if err == nil { - err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err() + if err != nil { + if err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)); err == nil { + err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err() + } } } return err @@ -160,7 +160,7 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe Bool("v6", v6). Msg("announce peers") - return s.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd { + return s.GetPeers(ctx, ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { return s.SRandMemberN(ctx, infoHashKey, int64(maxCount)) }) } diff --git a/storage/memory/storage.go b/storage/memory/storage.go index 8697b81..a43453c 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -177,14 +177,15 @@ func (p *peers) len() int { return len(p.m) } -func (p *peers) keys(fn func(k bittorrent.Peer) bool) { +func (p *peers) keys(fn func(k bittorrent.Peer) bool) bool { p.RLock() + defer p.RUnlock() for k := range p.m { if !fn(k) { - break + return false } } - p.RUnlock() + return true } func (p *peers) forEach(fn func(k bittorrent.Peer, v int64) bool) { @@ -268,7 +269,7 @@ func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 { // There are twice the amount of shards specified by the user, the first // half is dedicated to IPv4 swarms and the second half is dedicated to // IPv6 swarms. - idx := binary.BigEndian.Uint32([]byte(infoHash[:4])) % (uint32(len(ps.shards)) / 2) + idx := binary.BigEndian.Uint32(infoHash.Bytes()[:4]) % (uint32(len(ps.shards)) / 2) if v6 { idx += uint32(len(ps.shards) / 2) } @@ -417,8 +418,7 @@ func (ps *peerStore) AnnouncePeers(_ context.Context, ih bittorrent.InfoHash, fo if forSeeder { sw.leechers.keys(rangeFn) } else { - sw.seeders.keys(rangeFn) - if numWant > 0 { + if sw.seeders.keys(rangeFn) { sw.leechers.keys(rangeFn) } } @@ -446,8 +446,8 @@ func (ps *peerStore) ScrapeSwarm(_ context.Context, ih bittorrent.InfoHash) (lee Stringer("infoHash", ih). Msg("scrape swarm") - leechers, seeders = ps.countPeers(ih, true) - l, s := ps.countPeers(ih, false) + leechers, seeders = ps.countPeers(ih, false) + l, s := ps.countPeers(ih, true) leechers, seeders = leechers+l, seeders+s return diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 631acec..b7873b0 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -359,15 +359,15 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) { }() } -func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { +func (s *store) putPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) { logger.Trace(). - Stringer("infoHash", ih). + Hex("infoHash", ih). Object("peer", peer). Bool("seeder", seeder). Msg("put peer") _, err = s.Exec(ctx, s.Peer.AddQuery, pgx.NamedArgs{ - pInfoHash: []byte(ih), - pPeerID: peer.ID[:], + pInfoHash: ih, + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), pSeeder: seeder, @@ -377,14 +377,14 @@ func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor return } -func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { +func (s *store) delPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) { logger.Trace(). - Stringer("infoHash", ih). + Hex("infoHash", ih). Object("peer", peer). Msg("del peer") _, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{ - pInfoHash: []byte(ih), - pPeerID: peer.ID[:], + pInfoHash: ih, + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), pSeeder: seeder, @@ -393,19 +393,19 @@ func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor } func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.putPeer(ctx, ih, peer, true) + return s.putPeer(ctx, ih.Bytes(), peer, true) } func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ctx, ih, peer, true) + return s.delPeer(ctx, ih.Bytes(), peer, true) } func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.putPeer(ctx, ih, peer, false) + return s.putPeer(ctx, ih.Bytes(), peer, false) } func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ctx, ih, peer, false) + return s.delPeer(ctx, ih.Bytes(), peer, false) } func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { @@ -414,10 +414,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee Object("peer", peer). Msg("graduate leecher") var batch pgx.Batch - ihb := []byte(ih) + ihb := ih.Bytes() batch.Queue(s.Peer.GraduateQuery, pgx.NamedArgs{ pInfoHash: ihb, - pPeerID: peer.ID[:], + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), }) @@ -425,10 +425,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee return s.txBatch(ctx, &batch) } -func (s *store) getPeers(ctx context.Context, ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) { +func (s *store) getPeers(ctx context.Context, ih []byte, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) { var rows pgx.Rows if rows, err = s.Query(ctx, s.Announce.Query, pgx.NamedArgs{ - pInfoHash: []byte(ih), + pInfoHash: ih, pSeeder: seeders, pV6: isV6, pCount: maxCount, @@ -503,12 +503,13 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe Int("numWant", numWant). Bool("v6", v6). Msg("announce peers") + ihb := ih.Bytes() if forSeeder { - peers, err = s.getPeers(ctx, ih, false, numWant, v6) + peers, err = s.getPeers(ctx, ihb, false, numWant, v6) } else { - if peers, err = s.getPeers(ctx, ih, true, numWant, v6); err == nil { + if peers, err = s.getPeers(ctx, ihb, true, numWant, v6); err == nil { var addPeers []bittorrent.Peer - addPeers, err = s.getPeers(ctx, ih, false, numWant-len(peers), v6) + addPeers, err = s.getPeers(ctx, ihb, false, numWant-len(peers), v6) peers = append(peers, addPeers...) } } @@ -568,7 +569,7 @@ func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leeche logger.Trace(). Stringer("infoHash", ih). Msg("scrape swarm") - ihb := []byte(ih) + ihb := ih.Bytes() if seeders, leechers, err = s.countPeers(ctx, ihb); err != nil { return } diff --git a/storage/redis/storage.go b/storage/redis/storage.go index f46ffbf..a60bcdf 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -24,13 +24,18 @@ package redis import ( "context" + "encoding/binary" "errors" + "fmt" + "net" + "net/netip" "strconv" "strings" "sync" "time" "github.com/redis/go-redis/v9" + "github.com/sot-tech/mochi/pkg/str2bytes" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" @@ -389,20 +394,30 @@ func (ps *store) delPeer(ctx context.Context, infoHashKey, peerCountKey, peerID return err } +// PackPeer generates concatenation of PeerID, net port and IP-address +func PackPeer(p bittorrent.Peer) string { + ip := p.Addr() + b := make([]byte, bittorrent.PeerIDLen+2+(ip.BitLen()/8)) + copy(b[:bittorrent.PeerIDLen], p.ID.Bytes()) + binary.BigEndian.PutUint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2], p.Port()) + copy(b[bittorrent.PeerIDLen+2:], ip.AsSlice()) + return str2bytes.BytesToString(b) +} + func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) + return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer)) } func (ps *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) + return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer)) } func (ps *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) + return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer)) } func (ps *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) + return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer)) } func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { @@ -411,7 +426,7 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe Object("peer", peer). Msg("graduate leecher") - infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6() + infoHash, peerID, isV6 := ih.RawString(), PackPeer(peer), peer.Addr().Is6() ihSeederKey, ihLeecherKey := InfoHashKey(infoHash, true, isV6), InfoHashKey(infoHash, false, isV6) return ps.tx(ctx, func(tx redis.Pipeliner) error { @@ -438,12 +453,42 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe }) } +// peerMinimumLen is the least allowed length of string serialized Peer +const peerMinimumLen = bittorrent.PeerIDLen + 2 + net.IPv4len + +var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (InfoHash + Port + IPv4))", peerMinimumLen) + +// UnpackPeer constructs Peer from serialized by Peer.PackPeer data: PeerID[20by]Port[2by]net.IP[4/16by] +func UnpackPeer(data string) (bittorrent.Peer, error) { + var peer bittorrent.Peer + if len(data) < peerMinimumLen { + return peer, errInvalidPeerDataSize + } + b := str2bytes.StringToBytes(data) + peerID, err := bittorrent.NewPeerID(b[:bittorrent.PeerIDLen]) + if err == nil { + if addr, isOk := netip.AddrFromSlice(b[bittorrent.PeerIDLen+2:]); isOk { + peer = bittorrent.Peer{ + ID: peerID, + AddrPort: netip.AddrPortFrom( + addr.Unmap(), + binary.BigEndian.Uint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2]), + ), + } + } else { + err = bittorrent.ErrInvalidIP + } + } + + return peer, err +} + func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers []bittorrent.Peer, err error) { var peerIds []string peerIds, err = peersResult.Result() if err = NoResultErr(err); err == nil { for _, peerID := range peerIds { - if p, err := bittorrent.NewPeer(peerID); err == nil { + if p, err := UnpackPeer(peerID); err == nil { peers = append(peers, p) } else { logger.Error().Err(err).Str("peerID", peerID).Msg("unable to decode peer") @@ -453,13 +498,13 @@ func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers [ return } -type getPeersFn func(string, int) *redis.StringSliceCmd +type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd // GetPeers retrieves peers for provided info hash by calling membersFn and // converts result to bittorrent.Peer array. // If forSeeder set to true - returns only leechers, if false - // seeders and if maxCount not reached - leechers. -func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) { +func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) { infoHash := ih.RawString() infoHashKeys := make([]string, 1, 2) @@ -473,7 +518,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount for _, infoHashKey := range infoHashKeys { var peers []bittorrent.Peer - peers, err = ps.parsePeersList(membersFn(infoHashKey, maxCount)) + peers, err = ps.parsePeersList(membersFn(ctx, infoHashKey, maxCount)) maxCount -= len(peers) out = append(out, peers...) if err != nil || maxCount <= 0 { @@ -501,9 +546,7 @@ func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forS Bool("v6", v6). Msg("announce peers") - return ps.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd { - return ps.HRandField(ctx, infoHashKey, maxCount) - }) + return ps.GetPeers(ctx, ih, forSeeder, numWant, v6, ps.HRandField) } type getPeerCountFn func(context.Context, string) *redis.IntCmd diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index 20bdfe1..3d0496a 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -17,8 +17,8 @@ import ( ) const ( - ihCount = 1000 - peersCount = 1000 + ihCount = 100 + peersCount = 100 ) type benchData struct {