diff --git a/bittorrent/peer.go b/bittorrent/peer.go index 751fb21..8d0b323 100644 --- a/bittorrent/peer.go +++ b/bittorrent/peer.go @@ -9,9 +9,9 @@ import ( "encoding/hex" "fmt" "net/netip" - "unsafe" "github.com/rs/zerolog" + "github.com/sot-tech/mochi/pkg/str2bytes" ) // PeerIDLen is length of peer id field in bytes @@ -23,17 +23,23 @@ type PeerID [PeerIDLen]byte // ErrInvalidPeerIDSize holds error about invalid PeerID size var ErrInvalidPeerIDSize = fmt.Errorf("peer ID must be %d bytes", PeerIDLen) +var zeroPeerID PeerID + // NewPeerID creates a PeerID from a byte slice. // // It panics if b is not 20 bytes long. func NewPeerID(b []byte) (PeerID, error) { - var p PeerID if len(b) != PeerIDLen { - return p, ErrInvalidPeerIDSize + return zeroPeerID, ErrInvalidPeerIDSize } return PeerID(b), nil } +// Bytes returns slice of bytes represents this PeerID +func (p PeerID) Bytes() []byte { + return p[:] +} + // String implements fmt.Stringer, returning the base16 encoded PeerID. func (p PeerID) String() string { return hex.EncodeToString(p.Bytes()) @@ -41,12 +47,7 @@ func (p PeerID) String() string { // RawString returns a 20-byte string of the raw bytes of the ID. func (p PeerID) RawString() string { - return unsafe.String(&p[0], PeerIDLen) -} - -// Bytes returns slice of bytes represents this PeerID -func (p PeerID) Bytes() []byte { - return p[:] + return str2bytes.BytesToString(p.Bytes()) } // InfoHash represents an infohash. @@ -62,16 +63,6 @@ const ( // ErrInvalidHashSize holds error about invalid InfoHash size var ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes or same sizes x2 (if HEX encoded)", InfoHashV1Len, InfoHashV2Len) -// TruncateV1 returns truncated to 20-bytes length array of the corresponding InfoHash. -// If InfoHash is V2 (32 bytes), it will be truncated to 20 bytes -// according to BEP52. -func (i InfoHash) TruncateV1() InfoHash { - if len(i) == InfoHashV2Len { - return i[:InfoHashV1Len] - } - return i -} - // NewInfoHash creates an InfoHash from raw/hex byte slice. func NewInfoHash(data []byte) (InfoHash, error) { var ih InfoHash @@ -83,7 +74,7 @@ func NewInfoHash(data []byte) (InfoHash, error) { if _, err := hex.Decode(bb, data); err != nil { return "", err } - ih = InfoHash(unsafe.String(&bb[0], len(bb))) + ih = InfoHash(str2bytes.BytesToString(bb)) default: return "", ErrInvalidHashSize } @@ -92,7 +83,22 @@ func NewInfoHash(data []byte) (InfoHash, error) { // NewInfoHashString creates an InfoHash from raw/hex string. func NewInfoHashString(data string) (InfoHash, error) { - return NewInfoHash(unsafe.Slice(unsafe.StringData(data), len(data))) + return NewInfoHash(str2bytes.StringToBytes(data)) +} + +// TruncateV1 returns truncated to 20-bytes length array of the corresponding InfoHash. +// If InfoHash is V2 (32 bytes), it will be truncated to 20 bytes +// according to BEP52. +func (i InfoHash) TruncateV1() InfoHash { + if len(i) == InfoHashV2Len { + return i[:InfoHashV1Len] + } + return i +} + +// Bytes returns slice of bytes represents this InfoHash +func (i InfoHash) Bytes() []byte { + return str2bytes.StringToBytes(string(i)) } // String implements fmt.Stringer, returning the base16 encoded InfoHash. @@ -105,11 +111,6 @@ func (i InfoHash) RawString() string { return string(i) } -// Bytes returns slice of bytes represents this InfoHash -func (i InfoHash) Bytes() []byte { - return unsafe.Slice(unsafe.StringData(string(i)), len(i)) -} - // Peer represents the connection details of a peer that is returned in an // announce response. type Peer struct { diff --git a/cmd/mochi/server_test.go b/cmd/mochi/server_test.go new file mode 100644 index 0000000..5af270b --- /dev/null +++ b/cmd/mochi/server_test.go @@ -0,0 +1,288 @@ +package main + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math" + "math/rand" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + cr "crypto/rand" + + "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/frontend" + hf "github.com/sot-tech/mochi/frontend/http" + l "github.com/sot-tech/mochi/pkg/log" + "github.com/sot-tech/mochi/pkg/str2bytes" +) + +const ( + connectUDPPacketSize = 16 + announceUDPPacketSize = 98 + announceNumWant = 10 + + timeout = time.Second * 2 +) + +var ( + udpConnectHeader = []byte{0x0, 0x0, 0x4, 0x17, 0x27, 0x10, 0x19, 0x80} + errUDPSendTruncated = errors.New("data not fully sent") + errUDPRecvTruncated = errors.New("data not fully received") + errUDPUnexpectedAction = errors.New("unexpected action") + errTxIDMissmatch = errors.New("transaction ID missmatch") + hashes = make([][]byte, 100) + peers = make([][]byte, 100) +) + +func init() { + _ = l.ConfigureLogger("", "error", false, false) + for i := range hashes { + bb := make([]byte, bittorrent.InfoHashV1Len) + if _, err := cr.Read(bb); err != nil { + panic(err) + } + hashes[i] = bb + } + + for i := range peers { + bb := make([]byte, bittorrent.PeerIDLen) + if _, err := cr.Read(bb); err != nil { + panic(err) + } + peers[i] = bb + } +} + +func buildUDPConnReq() []byte { + req := make([]byte, connectUDPPacketSize) + copy(req, udpConnectHeader) + + // TxID + binary.BigEndian.PutUint32(req[12:16], rand.Uint32()) + return req +} + +func sendUDPConnReq(addr string) ([]byte, []byte, error) { + req := buildUDPConnReq() + + conn, err := net.Dial("udp", addr) + if err != nil { + return nil, nil, err + } + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + + defer conn.Close() + + n, err := conn.Write(req) + if err != nil { + return nil, nil, err + } + + if n != len(req) { + return nil, nil, errUDPSendTruncated + } + + resp := make([]byte, 16) + n, err = conn.Read(resp) + if err != nil { + return nil, nil, err + } + + if n != len(resp) { + return nil, nil, errUDPRecvTruncated + } + + action := binary.BigEndian.Uint32(resp[:4]) + if action != 0 { + return nil, nil, errUDPUnexpectedAction + } + + if !bytes.Equal(resp[4:8], req[12:16]) { + return nil, nil, errTxIDMissmatch + } + + // TxID, ConnectionID + return resp[4:8], resp[8:16], nil +} + +func buildAnnounceUDPReq(txID, connID []byte) []byte { + req := make([]byte, announceUDPPacketSize) + + // Connection ID + copy(req[:8], connID) + + // Action + req[11] = 1 + + // TxID + copy(req[12:16], txID) + + // InfoHash + copy(req[16:36], hashes[rand.Intn(len(hashes))]) + + // PeerID + copy(req[36:56], peers[rand.Intn(len(peers))]) + + var down, left uint64 + if rand.Intn(2) == 0 { + down, left = 1, 0 + } else { + down, left = 0, 1 + } + // Downloaded + binary.BigEndian.PutUint64(req[56:64], down) + // Left + binary.BigEndian.PutUint64(req[64:72], left) + + // Event + req[83] = 1 + + // Numwant + req[92], req[95] = byte(announceNumWant>>24), byte(announceNumWant>>16) + + // Port + p := rand.Intn(math.MaxInt16) + 1 + req[96], req[97] = byte(p>>8), byte(p) + return req +} + +func BenchmarkServerUDPAnnounce(b *testing.B) { + var s Server + if err := s.Run(QuickConfig); err != nil { + b.Fatal(err) + } + defer s.Shutdown() + + addr := "127.0.0.1" + frontend.DefaultListenAddress + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + txID, connID, err := sendUDPConnReq(addr) + if err != nil { + b.Log(err) + return + } + req := buildAnnounceUDPReq(txID, connID) + + conn, err := net.Dial("udp", addr) + if err != nil { + b.Log(err) + return + } + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + n, err := conn.Write(req) + if err != nil { + _ = conn.Close() + b.Log(err) + return + } + if n != announceUDPPacketSize { + _ = conn.Close() + b.Logf("packet not fully sent, %d bytes instead of %d", n, announceUDPPacketSize) + return + } + resp := make([]byte, 1024) + n, err = conn.Read(resp) + _ = conn.Close() + if err != nil { + b.Log(err) + return + } + if n < 20 { + b.Logf("packet not fully received, only %d bytes", n) + return + } + action := binary.BigEndian.Uint32(resp[:4]) + if action != 1 { + if action == 3 { + errVal := string(resp[8:n]) + b.Logf("tracker error: %s", errVal) + } else { + b.Logf("unexpected action: %d", action) + } + return + } + + if !bytes.Equal(resp[4:8], req[12:16]) { + b.Log("transaction ID missmatch") + } + } + }) +} + +func sendHTTPReq(u string) (err error) { + var r *http.Response + // nolint:gosec + if r, err = http.Get(u); err == nil { + defer r.Body.Close() + if r.StatusCode < 400 { + var out []byte + if out, err = io.ReadAll(r.Body); err == nil { + sout := string(out) + if strings.Contains(sout, "failure reason") { + return errors.New(sout) + } + } + } else { + return errors.New(r.Status) + } + } + return +} + +func BenchmarkServerHTTPAnnounce(b *testing.B) { + var s Server + if err := s.Run(QuickConfig); err != nil { + b.Fatal(err) + } + defer s.Shutdown() + + reqs := make([]string, len(hashes)*len(peers)) + addr := "127.0.0.1" + frontend.DefaultListenAddress + for i := range reqs { + var down, left string + if rand.Intn(2) == 0 { + down, left = "1", "0" + } else { + down, left = "0", "1" + } + u := url.URL{ + Scheme: "http", + Host: addr, + Path: hf.DefaultAnnounceRoute, + RawQuery: url.Values{ + "event": []string{bittorrent.StartedStr}, + "compact": []string{"1"}, + "left": []string{left}, + "downloaded": []string{down}, + "uploaded": []string{"0"}, + "numwant": []string{"1"}, + "port": []string{strconv.FormatInt(int64(rand.Intn(math.MaxInt16)+1), 10)}, + "info_hash": []string{str2bytes.BytesToString(hashes[rand.Intn(len(hashes))])}, + "peer_id": []string{str2bytes.BytesToString(peers[rand.Intn(len(peers))])}, + }.Encode(), + } + reqs[i] = u.String() + } + + var cnt atomic.Uint32 + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := sendHTTPReq(reqs[int(cnt.Add(1))%len(reqs)]); err != nil { + b.Log(err) + } + } + }) +} diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index 13df09d..58a113c 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -51,11 +51,15 @@ type Config struct { } const ( - defaultReadTimeout = 2 * time.Second - defaultWriteTimeout = 2 * time.Second - defaultIdleTimeout = 30 * time.Second - defaultAnnounceRoute = "/announce" - defaultScrapeRoute = "/scrape" + defaultReadTimeout = 2 * time.Second + defaultWriteTimeout = 2 * time.Second + defaultIdleTimeout = 30 * time.Second + // DefaultAnnounceRoute is the default url path to listen announce + // requests if nothing else provided + DefaultAnnounceRoute = "/announce" + // DefaultScrapeRoute is the default url path to listen scrape + // requests if nothing else provided + DefaultScrapeRoute = "/scrape" ) // Validate sanity checks values set in a config and returns a new config with @@ -98,7 +102,7 @@ func (cfg Config) Validate() (validCfg Config, err error) { } } if len(cfg.AnnounceRoutes) == 0 { - validCfg.AnnounceRoutes = []string{defaultAnnounceRoute} + validCfg.AnnounceRoutes = []string{DefaultAnnounceRoute} logger.Warn(). Str("name", "AnnounceRoutes"). Strs("provided", cfg.AnnounceRoutes). @@ -106,7 +110,7 @@ func (cfg Config) Validate() (validCfg Config, err error) { Msg("falling back to default configuration") } if len(cfg.ScrapeRoutes) == 0 { - validCfg.ScrapeRoutes = []string{defaultScrapeRoute} + validCfg.ScrapeRoutes = []string{DefaultScrapeRoute} logger.Warn(). Str("name", "ScrapeRoutes"). Strs("provided", cfg.ScrapeRoutes). @@ -211,7 +215,10 @@ func runServer(s *fasthttp.Server, cfg *Config) { err = s.ServeTLS(ln, "", "") } } - if !errors.Is(err, http.ErrServerClosed) { + defer ln.Close() + if err == nil { + logger.Info().Msg("server stopped") + } else if !errors.Is(err, http.ErrServerClosed) { logger.Fatal().Err(err).Msg("server failed") } } diff --git a/frontend/http/frontend_test.go b/frontend/http/frontend_test.go index 1df4904..f47271b 100644 --- a/frontend/http/frontend_test.go +++ b/frontend/http/frontend_test.go @@ -97,7 +97,7 @@ func BenchmarkAnnounce(b *testing.B) { u := url.URL{ Scheme: "http", Host: addr, - Path: defaultAnnounceRoute, + Path: DefaultAnnounceRoute, RawQuery: url.Values{ "event": []string{bittorrent.StartedStr}, "compact": []string{"1"}, @@ -121,7 +121,7 @@ func BenchmarkScrape(b *testing.B) { u := url.URL{ Scheme: "http", Host: addr, - Path: defaultScrapeRoute, + Path: DefaultScrapeRoute, RawQuery: url.Values{"info_hash": hashes[:len(hashes)/2]}.Encode(), } if err := runGet(u.String(), true); err != nil { diff --git a/frontend/http/writer.go b/frontend/http/writer.go index b384703..28c5610 100644 --- a/frontend/http/writer.go +++ b/frontend/http/writer.go @@ -96,7 +96,7 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) { bb.WriteString(addr) if includePeerID { bb.WriteString("7:peer id20:") - bb.Write(peer.ID[:]) + bb.Write(peer.ID.Bytes()) } bb.WriteString("4:porti") bb.Write(fasthttp.AppendUint(nil, int(peer.Port()))) diff --git a/frontend/options.go b/frontend/options.go index 61fc631..ba315c7 100644 --- a/frontend/options.go +++ b/frontend/options.go @@ -9,7 +9,9 @@ import ( "github.com/libp2p/go-reuseport" ) -const defaultListenAddress = ":6969" +// DefaultListenAddress is the default net listen address +// if nothing else provided +const DefaultListenAddress = ":6969" var errUnexpectedListenerType = errors.New("unexpected listener type") @@ -26,7 +28,7 @@ type ListenOptions struct { func (lo ListenOptions) Validate(logger *log.Logger) (validOptions ListenOptions) { validOptions = lo if len(lo.Addr) == 0 { - validOptions.Addr = defaultListenAddress + validOptions.Addr = DefaultListenAddress logger.Warn(). Str("name", "Addr"). Str("provided", lo.Addr). diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index de6f4fa..cb229b8 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -57,10 +57,10 @@ func (cfg Config) Validate() (validCfg Config) { validCfg.ListenOptions = cfg.ListenOptions.Validate(logger) if cfg.Workers == 0 { - cfg.Workers = 1 + validCfg.Workers = 1 } - if cfg.Workers > 1 && !cfg.ReusePort { - cfg.ReusePort = true + if validCfg.Workers > 1 && !validCfg.ReusePort { + validCfg.ReusePort = true logger.Warn().Msg("forcibly enabling ReusePort because Workers > 1") } diff --git a/middleware/clientapproval/client_id.go b/middleware/clientapproval/client_id.go index 601bc6a..685bf67 100644 --- a/middleware/clientapproval/client_id.go +++ b/middleware/clientapproval/client_id.go @@ -4,7 +4,10 @@ // - https://github.com/webtorrent/bittorrent-peerid/blob/master/lib/utils.js package clientapproval -import "github.com/sot-tech/mochi/bittorrent" +import ( + "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/str2bytes" +) // ClientID represents the part of a PeerID that identifies a Peer's client // software. @@ -23,5 +26,5 @@ func NewClientID(pid bittorrent.PeerID) ClientID { } func (cid ClientID) String() string { - return string(cid[:]) + return str2bytes.BytesToString(cid[:]) } diff --git a/pkg/str2bytes/str2bytes.go b/pkg/str2bytes/str2bytes.go new file mode 100644 index 0000000..1309c81 --- /dev/null +++ b/pkg/str2bytes/str2bytes.go @@ -0,0 +1,21 @@ +// Package str2bytes provides fast, but unsafe functions to convert string to []byte +// or vice versa. +package str2bytes + +import "unsafe" + +// StringToBytes converts string to slice of bytes +// without allocation. Note, that returned slice +// must NOT be modified, since strings in Go are +// immutable. +// See unsafe.Slice. +func StringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// BytesToString converts slice of bytes to string +// without allocation. +// See unsafe.String +func BytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index ba5c1ea..4ab2fac 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -139,13 +139,13 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6()) var moved bool if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil { - if moved { - err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)) - } else { - err = s.addPeer(ctx, ihSeederKey, peerID) + if !moved { + err = s.SAdd(ctx, ihSeederKey, peerID).Err() } - if err == nil { - err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err() + if err != nil { + if err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)); err == nil { + err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err() + } } } return err diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 554c2af..b7873b0 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -367,7 +367,7 @@ func (s *store) putPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, se Msg("put peer") _, err = s.Exec(ctx, s.Peer.AddQuery, pgx.NamedArgs{ pInfoHash: ih, - pPeerID: peer.ID[:], + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), pSeeder: seeder, @@ -384,7 +384,7 @@ func (s *store) delPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, se Msg("del peer") _, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{ pInfoHash: ih, - pPeerID: peer.ID[:], + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), pSeeder: seeder, @@ -417,7 +417,7 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee ihb := ih.Bytes() batch.Queue(s.Peer.GraduateQuery, pgx.NamedArgs{ pInfoHash: ihb, - pPeerID: peer.ID[:], + pPeerID: peer.ID.Bytes(), pAddress: net.IP(peer.Addr().AsSlice()), pPort: peer.Port(), }) diff --git a/storage/redis/storage.go b/storage/redis/storage.go index 64f12fc..1c94ed9 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -33,9 +33,9 @@ import ( "strings" "sync" "time" - "unsafe" "github.com/redis/go-redis/v9" + "github.com/sot-tech/mochi/pkg/str2bytes" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" @@ -401,7 +401,7 @@ func PackPeer(p bittorrent.Peer) string { copy(b[:bittorrent.PeerIDLen], p.ID.Bytes()) binary.BigEndian.PutUint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2], p.Port()) copy(b[bittorrent.PeerIDLen+2:], ip.AsSlice()) - return unsafe.String(&b[0], len(b)) + return str2bytes.BytesToString(b) } func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { @@ -464,7 +464,7 @@ func UnpackPeer(data string) (bittorrent.Peer, error) { if len(data) < peerMinimumLen { return peer, errInvalidPeerDataSize } - b := unsafe.Slice(unsafe.StringData(data), len(data)) + b := str2bytes.StringToBytes(data) peerID, err := bittorrent.NewPeerID(b[:bittorrent.PeerIDLen]) if err == nil { if addr, isOk := netip.AddrFromSlice(b[bittorrent.PeerIDLen+2:]); isOk { diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index 20bdfe1..3d0496a 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -17,8 +17,8 @@ import ( ) const ( - ihCount = 1000 - peersCount = 1000 + ihCount = 100 + peersCount = 100 ) type benchData struct {