Merge pull request #35 from sot-tech/test

Refactor code
This commit is contained in:
SOT-TECH
2023-03-23 00:41:53 +03:00
committed by GitHub
34 changed files with 589 additions and 313 deletions
+2 -2
View File
@@ -13,9 +13,9 @@ import (
//
// See ParseURLData for specifics on parsing and limitations.
type Params interface {
// String returns a string parsed from a query. Every key can be
// GetString returns a string parsed from a query. Every key can be
// returned as a string because they are encoded in the URL as strings.
String(key string) (string, bool)
GetString(key string) (string, bool)
zerolog.LogObjectMarshaler
}
+41 -96
View File
@@ -6,14 +6,12 @@ package bittorrent
import (
"crypto/sha1"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// PeerIDLen is length of peer id field in bytes
@@ -25,26 +23,31 @@ type PeerID [PeerIDLen]byte
// ErrInvalidPeerIDSize holds error about invalid PeerID size
var ErrInvalidPeerIDSize = fmt.Errorf("peer ID must be %d bytes", PeerIDLen)
var zeroPeerID PeerID
// NewPeerID creates a PeerID from a byte slice.
//
// It panics if b is not 20 bytes long.
func NewPeerID(b []byte) (PeerID, error) {
var p PeerID
if len(b) != PeerIDLen {
return p, ErrInvalidPeerIDSize
return zeroPeerID, ErrInvalidPeerIDSize
}
copy(p[:], b)
return p, nil
return PeerID(b), nil
}
// Bytes returns slice of bytes represents this PeerID
func (p PeerID) Bytes() []byte {
return p[:]
}
// String implements fmt.Stringer, returning the base16 encoded PeerID.
func (p PeerID) String() string {
return hex.EncodeToString(p[:])
return hex.EncodeToString(p.Bytes())
}
// RawString returns a 20-byte string of the raw bytes of the ID.
func (p PeerID) RawString() string {
return string(p[:])
return str2bytes.BytesToString(p.Bytes())
}
// InfoHash represents an infohash.
@@ -55,16 +58,33 @@ const (
InfoHashV1Len = sha1.Size
// InfoHashV2Len ... sha256.Size
InfoHashV2Len = sha256.Size
// NoneInfoHash dummy invalid InfoHash
NoneInfoHash InfoHash = ""
)
var (
// ErrInvalidHashType holds error about invalid InfoHash input type
ErrInvalidHashType = errors.New("info hash must be provided as byte slice or raw/hex string")
// ErrInvalidHashSize holds error about invalid InfoHash size
ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes", InfoHashV1Len, InfoHashV2Len)
)
// ErrInvalidHashSize holds error about invalid InfoHash size
var ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes or same sizes x2 (if HEX encoded)", InfoHashV1Len, InfoHashV2Len)
// NewInfoHash creates an InfoHash from raw/hex byte slice.
func NewInfoHash(data []byte) (InfoHash, error) {
var ih InfoHash
switch l := len(data); l {
case InfoHashV1Len, InfoHashV2Len:
ih = InfoHash(data)
case InfoHashV1Len * 2, InfoHashV2Len * 2:
bb := make([]byte, l/2)
if _, err := hex.Decode(bb, data); err != nil {
return "", err
}
ih = InfoHash(str2bytes.BytesToString(bb))
default:
return "", ErrInvalidHashSize
}
return ih, nil
}
// NewInfoHashString creates an InfoHash from raw/hex string.
func NewInfoHashString(data string) (InfoHash, error) {
return NewInfoHash(str2bytes.StringToBytes(data))
}
// TruncateV1 returns truncated to 20-bytes length array of the corresponding InfoHash.
// If InfoHash is V2 (32 bytes), it will be truncated to 20 bytes
@@ -76,48 +96,14 @@ func (i InfoHash) TruncateV1() InfoHash {
return i
}
// NewInfoHash creates an InfoHash from a byte slice or raw/hex string.
func NewInfoHash(data any) (InfoHash, error) {
if data == nil {
return NoneInfoHash, ErrInvalidHashType
}
var ba []byte
switch t := data.(type) {
case [InfoHashV1Len]byte:
ba = t[:]
case [InfoHashV2Len]byte:
ba = t[:]
case []byte:
l := len(t)
if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 {
ba = make([]byte, l/2)
if _, err := hex.Decode(ba, t); err != nil {
return NoneInfoHash, err
}
} else {
ba = t
}
case string:
l := len(t)
if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 {
var err error
if ba, err = hex.DecodeString(t); err != nil {
return NoneInfoHash, err
}
} else {
ba = []byte(t)
}
}
l := len(ba)
if l != InfoHashV1Len && l != InfoHashV2Len {
return NoneInfoHash, ErrInvalidHashSize
}
return InfoHash(ba), nil
// Bytes returns slice of bytes represents this InfoHash
func (i InfoHash) Bytes() []byte {
return str2bytes.StringToBytes(string(i))
}
// String implements fmt.Stringer, returning the base16 encoded InfoHash.
func (i InfoHash) String() string {
return hex.EncodeToString([]byte(i))
return hex.EncodeToString(i.Bytes())
}
// RawString returns a string of the raw bytes of the InfoHash.
@@ -132,47 +118,6 @@ type Peer struct {
netip.AddrPort
}
// PeerMinimumLen is the least allowed length of string serialized Peer
const PeerMinimumLen = PeerIDLen + 2 + net.IPv4len
// ErrInvalidPeerDataSize holds error about invalid Peer data size
var ErrInvalidPeerDataSize = fmt.Errorf("invalid peer data it must be at least %d bytes (InfoHash + Port + IPv4)", PeerMinimumLen)
// NewPeer constructs Peer from serialized by Peer.RawString data: PeerID[20by]Port[2by]net.IP[4/16by]
func NewPeer(data string) (Peer, error) {
var peer Peer
if len(data) < PeerMinimumLen {
return peer, ErrInvalidPeerDataSize
}
b := []byte(data)
peerID, err := NewPeerID(b[:PeerIDLen])
if err == nil {
if addr, isOk := netip.AddrFromSlice(b[PeerIDLen+2:]); isOk {
peer = Peer{
ID: peerID,
AddrPort: netip.AddrPortFrom(
addr.Unmap(),
binary.BigEndian.Uint16(b[PeerIDLen:PeerIDLen+2]),
),
}
} else {
err = ErrInvalidIP
}
}
return peer, err
}
// RawString generates concatenation of PeerID, net port and IP-address
func (p Peer) RawString() string {
ip := p.Addr()
b := make([]byte, PeerIDLen+2+(ip.BitLen()/8))
copy(b[:PeerIDLen], p.ID[:])
binary.BigEndian.PutUint16(b[PeerIDLen:PeerIDLen+2], p.Port())
copy(b[PeerIDLen+2:], ip.AsSlice())
return string(b)
}
// Addr returns unmapped peer's IP address
func (p Peer) Addr() netip.Addr {
return p.AddrPort.Addr().Unmap()
+5 -19
View File
@@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) {
// Scrapes wrapper of array of Scrape-s
type Scrapes []Scrape
func (s *Scrapes) Len() int {
return len(*s)
}
func (s *Scrapes) Less(i, j int) bool {
return (*s)[i].InfoHash < (*s)[j].InfoHash
}
func (s *Scrapes) Swap(i, j int) {
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
}
// MarshalZerologArray writes array elements to zerolog event
func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
if s != nil {
for _, scrape := range *s {
a.Object(scrape)
}
func (s Scrapes) MarshalZerologArray(a *zerolog.Array) {
for _, scrape := range s {
a.Object(scrape)
}
}
@@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
// The Scrapes must be in the same order as the InfoHashes in the corresponding
// ScrapeRequest.
type ScrapeResponse struct {
Files Scrapes
Data Scrapes
}
// MarshalZerologObject writes fields into zerolog event
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
e.Array("scrapes", &sr.Files)
e.Array("scrapes", sr.Data)
}
+288
View File
@@ -0,0 +1,288 @@
package main
import (
"bytes"
"encoding/binary"
"errors"
"io"
"math"
"math/rand"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
cr "crypto/rand"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/frontend"
hf "github.com/sot-tech/mochi/frontend/http"
l "github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
const (
connectUDPPacketSize = 16
announceUDPPacketSize = 98
announceNumWant = 10
timeout = time.Second * 2
)
var (
udpConnectHeader = []byte{0x0, 0x0, 0x4, 0x17, 0x27, 0x10, 0x19, 0x80}
errUDPSendTruncated = errors.New("data not fully sent")
errUDPRecvTruncated = errors.New("data not fully received")
errUDPUnexpectedAction = errors.New("unexpected action")
errTxIDMissmatch = errors.New("transaction ID missmatch")
hashes = make([][]byte, 100)
peers = make([][]byte, 100)
)
func init() {
_ = l.ConfigureLogger("", "error", false, false)
for i := range hashes {
bb := make([]byte, bittorrent.InfoHashV1Len)
if _, err := cr.Read(bb); err != nil {
panic(err)
}
hashes[i] = bb
}
for i := range peers {
bb := make([]byte, bittorrent.PeerIDLen)
if _, err := cr.Read(bb); err != nil {
panic(err)
}
peers[i] = bb
}
}
func buildUDPConnReq() []byte {
req := make([]byte, connectUDPPacketSize)
copy(req, udpConnectHeader)
// TxID
binary.BigEndian.PutUint32(req[12:16], rand.Uint32())
return req
}
func sendUDPConnReq(addr string) ([]byte, []byte, error) {
req := buildUDPConnReq()
conn, err := net.Dial("udp", addr)
if err != nil {
return nil, nil, err
}
_ = conn.SetReadDeadline(time.Now().Add(timeout))
defer conn.Close()
n, err := conn.Write(req)
if err != nil {
return nil, nil, err
}
if n != len(req) {
return nil, nil, errUDPSendTruncated
}
resp := make([]byte, 16)
n, err = conn.Read(resp)
if err != nil {
return nil, nil, err
}
if n != len(resp) {
return nil, nil, errUDPRecvTruncated
}
action := binary.BigEndian.Uint32(resp[:4])
if action != 0 {
return nil, nil, errUDPUnexpectedAction
}
if !bytes.Equal(resp[4:8], req[12:16]) {
return nil, nil, errTxIDMissmatch
}
// TxID, ConnectionID
return resp[4:8], resp[8:16], nil
}
func buildAnnounceUDPReq(txID, connID []byte) []byte {
req := make([]byte, announceUDPPacketSize)
// Connection ID
copy(req[:8], connID)
// Action
req[11] = 1
// TxID
copy(req[12:16], txID)
// InfoHash
copy(req[16:36], hashes[rand.Intn(len(hashes))])
// PeerID
copy(req[36:56], peers[rand.Intn(len(peers))])
var down, left uint64
if rand.Intn(2) == 0 {
down, left = 1, 0
} else {
down, left = 0, 1
}
// Downloaded
binary.BigEndian.PutUint64(req[56:64], down)
// Left
binary.BigEndian.PutUint64(req[64:72], left)
// Event
req[83] = 1
// Numwant
req[92], req[95] = byte(announceNumWant>>24), byte(announceNumWant>>16)
// Port
p := rand.Intn(math.MaxInt16) + 1
req[96], req[97] = byte(p>>8), byte(p)
return req
}
func BenchmarkServerUDPAnnounce(b *testing.B) {
var s Server
if err := s.Run(QuickConfig); err != nil {
b.Fatal(err)
}
defer s.Shutdown()
addr := "127.0.0.1" + frontend.DefaultListenAddress
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
txID, connID, err := sendUDPConnReq(addr)
if err != nil {
b.Log(err)
return
}
req := buildAnnounceUDPReq(txID, connID)
conn, err := net.Dial("udp", addr)
if err != nil {
b.Log(err)
return
}
_ = conn.SetReadDeadline(time.Now().Add(timeout))
n, err := conn.Write(req)
if err != nil {
_ = conn.Close()
b.Log(err)
return
}
if n != announceUDPPacketSize {
_ = conn.Close()
b.Logf("packet not fully sent, %d bytes instead of %d", n, announceUDPPacketSize)
return
}
resp := make([]byte, 1024)
n, err = conn.Read(resp)
_ = conn.Close()
if err != nil {
b.Log(err)
return
}
if n < 20 {
b.Logf("packet not fully received, only %d bytes", n)
return
}
action := binary.BigEndian.Uint32(resp[:4])
if action != 1 {
if action == 3 {
errVal := string(resp[8:n])
b.Logf("tracker error: %s", errVal)
} else {
b.Logf("unexpected action: %d", action)
}
return
}
if !bytes.Equal(resp[4:8], req[12:16]) {
b.Log("transaction ID missmatch")
}
}
})
}
func sendHTTPReq(u string) (err error) {
var r *http.Response
// nolint:gosec
if r, err = http.Get(u); err == nil {
defer r.Body.Close()
if r.StatusCode < 400 {
var out []byte
if out, err = io.ReadAll(r.Body); err == nil {
sout := string(out)
if strings.Contains(sout, "failure reason") {
return errors.New(sout)
}
}
} else {
return errors.New(r.Status)
}
}
return
}
func BenchmarkServerHTTPAnnounce(b *testing.B) {
var s Server
if err := s.Run(QuickConfig); err != nil {
b.Fatal(err)
}
defer s.Shutdown()
reqs := make([]string, len(hashes)*len(peers))
addr := "127.0.0.1" + frontend.DefaultListenAddress
for i := range reqs {
var down, left string
if rand.Intn(2) == 0 {
down, left = "1", "0"
} else {
down, left = "0", "1"
}
u := url.URL{
Scheme: "http",
Host: addr,
Path: hf.DefaultAnnounceRoute,
RawQuery: url.Values{
"event": []string{bittorrent.StartedStr},
"compact": []string{"1"},
"left": []string{left},
"downloaded": []string{down},
"uploaded": []string{"0"},
"numwant": []string{"1"},
"port": []string{strconv.FormatInt(int64(rand.Intn(math.MaxInt16)+1), 10)},
"info_hash": []string{str2bytes.BytesToString(hashes[rand.Intn(len(hashes))])},
"peer_id": []string{str2bytes.BytesToString(peers[rand.Intn(len(peers))])},
}.Encode(),
}
reqs[i] = u.String()
}
var cnt atomic.Uint32
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := sendHTTPReq(reqs[int(cnt.Add(1))%len(reqs)]); err != nil {
b.Log(err)
}
}
})
}
+3 -3
View File
@@ -190,10 +190,10 @@ prehooks:
#
# - name: client approval
# config:
# whitelist:
# client_id_list:
# - "OP1011"
# blacklist:
# - "OP1012"
# true - whitelist mode, false - blacklist
# invert: true
#
# - name: interval variation
# config:
+15 -8
View File
@@ -51,11 +51,15 @@ type Config struct {
}
const (
defaultReadTimeout = 2 * time.Second
defaultWriteTimeout = 2 * time.Second
defaultIdleTimeout = 30 * time.Second
defaultAnnounceRoute = "/announce"
defaultScrapeRoute = "/scrape"
defaultReadTimeout = 2 * time.Second
defaultWriteTimeout = 2 * time.Second
defaultIdleTimeout = 30 * time.Second
// DefaultAnnounceRoute is the default url path to listen announce
// requests if nothing else provided
DefaultAnnounceRoute = "/announce"
// DefaultScrapeRoute is the default url path to listen scrape
// requests if nothing else provided
DefaultScrapeRoute = "/scrape"
)
// Validate sanity checks values set in a config and returns a new config with
@@ -98,7 +102,7 @@ func (cfg Config) Validate() (validCfg Config, err error) {
}
}
if len(cfg.AnnounceRoutes) == 0 {
validCfg.AnnounceRoutes = []string{defaultAnnounceRoute}
validCfg.AnnounceRoutes = []string{DefaultAnnounceRoute}
logger.Warn().
Str("name", "AnnounceRoutes").
Strs("provided", cfg.AnnounceRoutes).
@@ -106,7 +110,7 @@ func (cfg Config) Validate() (validCfg Config, err error) {
Msg("falling back to default configuration")
}
if len(cfg.ScrapeRoutes) == 0 {
validCfg.ScrapeRoutes = []string{defaultScrapeRoute}
validCfg.ScrapeRoutes = []string{DefaultScrapeRoute}
logger.Warn().
Str("name", "ScrapeRoutes").
Strs("provided", cfg.ScrapeRoutes).
@@ -211,7 +215,10 @@ func runServer(s *fasthttp.Server, cfg *Config) {
err = s.ServeTLS(ln, "", "")
}
}
if !errors.Is(err, http.ErrServerClosed) {
defer ln.Close()
if err == nil {
logger.Info().Msg("server stopped")
} else if !errors.Is(err, http.ErrServerClosed) {
logger.Fatal().Err(err).Msg("server failed")
}
}
+5 -7
View File
@@ -17,8 +17,6 @@ import (
"github.com/sot-tech/mochi/pkg/log"
)
const iterations = 10000
var (
addr = fmt.Sprintf("127.0.0.1:%d", rand.Int63n(10000)+16384)
hashes = make([]string, 10)
@@ -87,7 +85,7 @@ func BenchmarkPing(b *testing.B) {
Path: "ping",
}
us := u.String()
for i := 0; i < iterations; i++ {
for i := 0; i < b.N; i++ {
if err := runGet(us, false); err != nil {
b.Error(err)
}
@@ -95,11 +93,11 @@ func BenchmarkPing(b *testing.B) {
}
func BenchmarkAnnounce(b *testing.B) {
for i := 0; i < iterations; i++ {
for i := 0; i < b.N; i++ {
u := url.URL{
Scheme: "http",
Host: addr,
Path: defaultAnnounceRoute,
Path: DefaultAnnounceRoute,
RawQuery: url.Values{
"event": []string{bittorrent.StartedStr},
"compact": []string{"1"},
@@ -119,11 +117,11 @@ func BenchmarkAnnounce(b *testing.B) {
}
func BenchmarkScrape(b *testing.B) {
for i := 0; i < iterations; i++ {
for i := 0; i < b.N; i++ {
u := url.URL{
Scheme: "http",
Host: addr,
Path: defaultScrapeRoute,
Path: DefaultScrapeRoute,
RawQuery: url.Values{"info_hash": hashes[:len(hashes)/2]}.Encode(),
}
if err := runGet(u.String(), true); err != nil {
+4 -3
View File
@@ -3,6 +3,7 @@ package http
import (
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/str2bytes"
"github.com/valyala/fasthttp"
)
@@ -12,11 +13,11 @@ type queryParams struct {
*fasthttp.Args
}
// String returns a string parsed from a query. Every key can be returned as a
// GetString returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings.
func (qp queryParams) String(key string) (string, bool) {
func (qp queryParams) GetString(key string) (string, bool) {
v := qp.Peek(key)
return string(v), v != nil
return str2bytes.BytesToString(v), v != nil
}
// InfoHashes returns a list of requested infohashes.
+6 -6
View File
@@ -5,10 +5,11 @@ import (
"errors"
"net/netip"
"github.com/valyala/fasthttp"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/frontend"
"github.com/valyala/fasthttp"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// ParseOptions is the configuration used to parse an Announce Request.
@@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
// Attempt to parse the event from the request.
var eventStr string
var err error
eventStr, request.EventProvided = qp.String("event")
eventStr, request.EventProvided = qp.GetString("event")
if request.EventProvided {
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
return nil, err
@@ -57,7 +58,6 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
if len(infoHashes) > 1 {
return nil, errMultipleInfoHashes
}
// FIXME: make sure that we have a copy of InfoHash
request.InfoHash = infoHashes[0]
// Parse the PeerID from the request.
@@ -137,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
if opts.AllowIPSpoofing {
for _, f := range []string{"ip", "ipv4", "ipv6"} {
if ipStr, ok := p.String(f); ok {
if ipStr, ok := p.GetString(f); ok {
addresses.Add(parseRequestAddress(ipStr, true))
}
}
@@ -147,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad
for _, ipStr := range ipValues {
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
addresses.Add(parseRequestAddress(string(ipStr), false))
addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false))
}
}
}
+24 -15
View File
@@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"sort"
"strconv"
"time"
@@ -17,7 +18,7 @@ import (
var respBufferPool = bytepool.NewBufferPool()
func writeErrorResponse(w io.StringWriter, err error) {
message := "internal server error"
message := "mochi internal error"
var clientErr bittorrent.ClientError
if errors.As(err, &clientErr) {
message = clientErr.Error()
@@ -96,28 +97,36 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) {
bb.WriteString(addr)
if includePeerID {
bb.WriteString("7:peer id20:")
bb.Write(peer.ID[:])
bb.Write(peer.ID.Bytes())
}
bb.WriteString("4:porti")
bb.Write(fasthttp.AppendUint(nil, int(peer.Port())))
bb.Write([]byte{'e', 'e'})
port := peer.Port()
bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'})
}
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
bb := respBufferPool.Get()
defer respBufferPool.Put(bb)
bb.WriteString("d5:filesd")
for _, scrape := range resp.Files {
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
bb.WriteByte(':')
bb.Write([]byte(scrape.InfoHash))
bb.WriteString("d8:completei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
bb.WriteString("e10:downloadedi")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
bb.WriteString("e10:incompletei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
bb.WriteString("ee")
l := len(resp.Data)
if l > 0 {
if l > 1 {
sort.Slice(resp.Data, func(i, j int) bool {
return resp.Data[i].InfoHash < resp.Data[j].InfoHash
})
}
for _, scrape := range resp.Data {
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
bb.WriteByte(':')
bb.Write([]byte(scrape.InfoHash))
bb.WriteString("d8:completei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
bb.WriteString("e10:downloadedi")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
bb.WriteString("e10:incompletei")
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
bb.Write([]byte{'e', 'e'})
}
}
bb.Write([]byte{'e', 'e'})
_, _ = bb.WriteTo(w)
+4 -2
View File
@@ -9,7 +9,9 @@ import (
"github.com/libp2p/go-reuseport"
)
const defaultListenAddress = ":6969"
// DefaultListenAddress is the default net listen address
// if nothing else provided
const DefaultListenAddress = ":6969"
var errUnexpectedListenerType = errors.New("unexpected listener type")
@@ -26,7 +28,7 @@ type ListenOptions struct {
func (lo ListenOptions) Validate(logger *log.Logger) (validOptions ListenOptions) {
validOptions = lo
if len(lo.Addr) == 0 {
validOptions.Addr = defaultListenAddress
validOptions.Addr = DefaultListenAddress
logger.Warn().
Str("name", "Addr").
Str("provided", lo.Addr).
+5 -5
View File
@@ -34,7 +34,7 @@ const (
var (
logger = log.NewLogger("frontend/udp")
allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
)
func init() {
@@ -57,16 +57,16 @@ func (cfg Config) Validate() (validCfg Config) {
validCfg.ListenOptions = cfg.ListenOptions.Validate(logger)
if cfg.Workers == 0 {
cfg.Workers = 1
validCfg.Workers = 1
}
if cfg.Workers > 1 && !cfg.ReusePort {
cfg.ReusePort = true
if validCfg.Workers > 1 && !validCfg.ReusePort {
validCfg.ReusePort = true
logger.Warn().Msg("forcibly enabling ReusePort because Workers > 1")
}
// Generate a private key if one isn't provided by the user.
if cfg.PrivateKey == "" {
pkeyRunes := make([]rune, defaultKeyLen)
pkeyRunes := make([]byte, defaultKeyLen)
for i := range pkeyRunes {
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
}
+5 -4
View File
@@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/str2bytes"
)
// ErrInvalidQueryEscape is returned when a query string contains invalid
@@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
key, value = key[:i], key[i+1:]
}
var k, v string
k, err = url.QueryUnescape(string(key))
k, err = url.QueryUnescape(str2bytes.BytesToString(key))
if err != nil {
// QueryUnescape returns an error like "invalid escape: '%x'".
// But frontends record these errors to prometheus, which generates
@@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
// We log it here for debugging instead.
return nil, ErrInvalidQueryEscape
}
v, err = url.QueryUnescape(string(value))
v, err = url.QueryUnescape(str2bytes.BytesToString(value))
if err != nil {
// QueryUnescape returns an error like "invalid escape: '%x'".
// But frontends record these errors to prometheus, which generates
@@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) {
return q, nil
}
// String returns a string parsed from a query. Every key can be returned as a
// GetString returns a string parsed from a query. Every key can be returned as a
// string because they are encoded in the URL as strings.
func (qp queryParams) String(key string) (string, bool) {
func (qp queryParams) GetString(key string) (string, bool) {
value, ok := qp.params[strings.ToLower(key)]
return value, ok
}
-2
View File
@@ -71,7 +71,6 @@ func parseAnnounce(r Request, v6Action bool, opts frontend.ParseOptions) (*bitto
// XXX: pure V2 hashes will cause invalid parsing,
// but BEP-52 says, that V2 hashes SHOULD be truncated
// FIXME: make sure that we have a copy of InfoHash
request.InfoHash, err = bittorrent.NewInfoHash(r.Packet[16:36])
if err != nil {
return nil, errInvalidInfoHash
@@ -178,7 +177,6 @@ func parseScrape(r Request, opts frontend.ParseOptions) (*bittorrent.ScrapeReque
var request *bittorrent.ScrapeRequest
for len(r.Packet) >= bittorrent.InfoHashV1Len {
var ih bittorrent.InfoHash
// FIXME: make sure that we have a copy of InfoHash
if ih, err = bittorrent.NewInfoHash(r.Packet[:bittorrent.InfoHashV1Len]); err == nil {
infoHashes = append(infoHashes, ih)
r.Packet = r.Packet[bittorrent.InfoHashV1Len:]
+1 -1
View File
@@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) {
t.Fatalf("expected values %v for %x", tt.values, tt.data)
} else {
for key, want := range tt.values {
if got, ok := params.String(key); !ok {
if got, ok := params.GetString(key); !ok {
t.Fatalf("params missing entry %s for data %x", key, tt.data)
} else if got != want {
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)
+12 -9
View File
@@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) {
buf := reqRespBufferPool.Get()
defer reqRespBufferPool.Put(buf)
writeHeader(buf, txID, errorActionID)
message := "mochi internal error"
var clientErr bittorrent.ClientError
// If the client wasn't at fault, acknowledge it.
if !errors.As(err, new(bittorrent.ClientError)) {
buf.WriteString("internal error occurred: ")
if errors.As(err, &clientErr) {
message = clientErr.Error()
} else {
logger.Error().Err(err).Msg("internal error")
}
buf.WriteString(err.Error())
buf.WriteString(message)
buf.WriteByte('\000')
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeAnnounceResponse encodes an announce response according to BEP 15.
@@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe
_ = binary.Write(buf, binary.BigEndian, peer.Port())
}
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeScrapeResponse encodes a scrape response according to BEP 15.
@@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon
writeHeader(buf, txID, scrapeActionID)
for _, scrape := range resp.Files {
for _, scrape := range resp.Data {
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
}
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeConnectionID encodes a new connection response according to BEP 15.
@@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) {
writeHeader(buf, txID, connectActionID)
buf.Write(connID)
_, _ = w.Write(buf.Bytes())
_, _ = buf.WriteTo(w)
}
// writeHeader writes the action and transaction ID to the provided response
+3 -3
View File
@@ -1,11 +1,11 @@
module github.com/sot-tech/mochi
go 1.19
go 1.20
require (
code.cloudfoundry.org/go-diodes v0.0.0-20230317203753-49f1af6d2f1a
github.com/MicahParks/keyfunc v1.9.0
github.com/anacrolix/torrent v1.48.0
github.com/anacrolix/torrent v1.49.0
github.com/cespare/xxhash/v2 v2.2.0
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/jackc/pgx/v5 v5.3.1
@@ -41,7 +41,7 @@ require (
github.com/klauspost/compress v1.16.3 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/onsi/ginkgo/v2 v2.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+4 -4
View File
@@ -40,8 +40,8 @@ github.com/anacrolix/stm v0.2.0/go.mod h1:zoVQRvSiGjGoTmbM0vSLIiaKjWtNPeTvXUSdJQ
github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
github.com/anacrolix/tagflag v1.0.0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
github.com/anacrolix/tagflag v1.1.0/go.mod h1:Scxs9CV10NQatSmbyjqmqmeQNwGzlNe0CMUMIxqHIG8=
github.com/anacrolix/torrent v1.48.0 h1:OQe1aQb8WnhDzpcI7r3yWoHzHWKyPbfhXGfO9Q/pvbY=
github.com/anacrolix/torrent v1.48.0/go.mod h1:3UtkJ8BnxXDRwvk+eT+uwiZalfFJ8YzAhvxe4QRPSJI=
github.com/anacrolix/torrent v1.49.0 h1:v/TAd8BKsZarYEYv7VkPNv8tY5zZCwQyxMMlKKbAF4I=
github.com/anacrolix/torrent v1.49.0/go.mod h1:qT3yS5oQwDUHnBXy+zf3nozLPudG7SFNDL3Jl/zQwFw=
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
@@ -176,8 +176,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
+5 -8
View File
@@ -4,7 +4,9 @@
// - https://github.com/webtorrent/bittorrent-peerid/blob/master/lib/utils.js
package clientapproval
import "github.com/sot-tech/mochi/bittorrent"
import (
"github.com/sot-tech/mochi/bittorrent"
)
// ClientID represents the part of a PeerID that identifies a Peer's client
// software.
@@ -14,14 +16,9 @@ type ClientID [6]byte
func NewClientID(pid bittorrent.PeerID) ClientID {
var cid ClientID
if pid[0] == '-' {
copy(cid[:], pid[1:7])
cid = ClientID(pid[1:7])
} else {
copy(cid[:], pid[:6])
cid = ClientID(pid[:6])
}
return cid
}
func (cid ClientID) String() string {
return string(cid[:])
}
+19 -43
View File
@@ -20,23 +20,21 @@ func init() {
middleware.RegisterBuilder(Name, build)
}
var (
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
ErrClientUnapproved = bittorrent.ClientError("unapproved client")
errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
)
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi")
// Config represents all the values required by this middleware to validate
// peers based on their BitTorrent client ID.
type Config struct {
Whitelist []string
Blacklist []string
// Static list of client IDs.
ClientIDList []string `cfg:"client_id_list"`
// If Invert set to true, all client IDs stored in ClientIDList should be blacklisted.
Invert bool
}
type hook struct {
approved map[ClientID]struct{}
unapproved map[ClientID]struct{}
clientIDs map[ClientID]any
invert bool
}
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
@@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error
}
h := &hook{
approved: make(map[ClientID]struct{}),
unapproved: make(map[ClientID]struct{}),
clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)),
invert: cfg.Invert,
}
if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
return nil, errBothListsProvided
}
for _, cidString := range cfg.Whitelist {
for _, cidString := range cfg.ClientIDList {
cidBytes := []byte(cidString)
if len(cidBytes) != 6 {
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
}
var cid ClientID
copy(cid[:], cidBytes)
h.approved[cid] = struct{}{}
}
for _, cidString := range cfg.Blacklist {
cidBytes := []byte(cidString)
if len(cidBytes) != 6 {
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
}
var cid ClientID
copy(cid[:], cidBytes)
h.unapproved[cid] = struct{}{}
h.clientIDs[ClientID(cidBytes)] = true
}
return h, nil
}
// HandleAnnounce checks if specified ClientID is approved or not.
// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved,
// that means that ClientID is blacklisted.
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
clientID := NewClientID(req.ID)
if len(h.approved) > 0 {
if _, found := h.approved[clientID]; !found {
return ctx, ErrClientUnapproved
}
var err error
if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert {
err = ErrClientUnapproved
}
if len(h.unapproved) > 0 {
if _, found := h.unapproved[clientID]; found {
return ctx, ErrClientUnapproved
}
}
return ctx, nil
return ctx, err
}
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
@@ -19,7 +19,7 @@ var cases = []struct {
// Client ID is whitelisted
{
Config{
Whitelist: []string{"010203"},
ClientIDList: []string{"010203"},
},
"01020304050607080900",
true,
@@ -27,7 +27,7 @@ var cases = []struct {
// Client ID is not whitelisted
{
Config{
Whitelist: []string{"010203"},
ClientIDList: []string{"010203"},
},
"10203040506070809000",
false,
@@ -35,7 +35,8 @@ var cases = []struct {
// Client ID is not blacklisted
{
Config{
Blacklist: []string{"010203"},
ClientIDList: []string{"010203"},
Invert: true,
},
"00000000001234567890",
true,
@@ -43,7 +44,8 @@ var cases = []struct {
// Client ID is blacklisted
{
Config{
Blacklist: []string{"123456"},
ClientIDList: []string{"123456"},
Invert: true,
},
"12345678900000000000",
false,
@@ -53,7 +55,7 @@ var cases = []struct {
func TestHandleAnnounce(t *testing.T) {
for _, tt := range cases {
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist}
c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert}
h, err := build(c, nil)
require.Nil(t, err)
@@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) {
nctx, err := h.HandleAnnounce(ctx, req, resp)
require.Equal(t, ctx, nctx)
if tt.approved == true {
if tt.approved {
require.NotEqual(t, err, ErrClientUnapproved)
} else {
require.Equal(t, err, ErrClientUnapproved)
+1 -1
View File
@@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
if err != nil {
return
}
resp.Files = append(resp.Files, scr)
resp.Data = append(resp.Data, scr)
}
return ctx, nil
+5 -5
View File
@@ -34,10 +34,10 @@ func init() {
var (
logger = log.NewLogger("middleware/jwt")
// ErrMissingJWT is returned when a JWT is missing from a request.
ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt")
ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt")
// ErrInvalidJWT is returned when a JWT fails to verify.
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt")
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
@@ -160,7 +160,7 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
err = ErrInvalidJWT
} else {
var claimIH bittorrent.InfoHash
if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil {
if claimIH, err = bittorrent.NewInfoHashString(claims.InfoHash); err != nil {
logger.Info().
Err(err).
Object("source", req.RequestPeer).
@@ -206,7 +206,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
} else {
var claimIHs bittorrent.InfoHashes
for _, s := range claims.InfoHashes {
if providedIh, err := bittorrent.NewInfoHash(s); err == nil {
if providedIh, err := bittorrent.NewInfoHashString(s); err == nil {
claimIHs = append(claimIHs, providedIh)
} else {
logger.Info().
@@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
if params != nil {
var found bool
if jwt, found = params.String(h.cfg.Header); found {
if jwt, found = params.GetString(h.cfg.Header); found {
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
jwt = jwt[len(bearerAuthPrefix):]
}
+1 -1
View File
@@ -55,7 +55,7 @@ type JWKSKeys struct {
type params map[string]string
func (p params) String(key string) (out string, found bool) {
func (p params) GetString(key string) (out string, found bool) {
out, found = p[key]
return
}
+1 -3
View File
@@ -2,7 +2,6 @@ package middleware
import (
"context"
"sort"
"time"
"github.com/sot-tech/mochi/bittorrent"
@@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
logger.Debug().Object("request", req).Msg("new scrape request")
resp = &bittorrent.ScrapeResponse{
Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
}
for _, h := range l.preHooks {
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
return nil, nil, err
}
}
sort.Sort(&resp.Files)
logger.Debug().Object("response", resp).Msg("generated scrape response")
return ctx, resp, nil
@@ -56,7 +56,7 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
if len(c.HashList) > 0 {
init := make([]storage.Entry, 0, len(c.HashList))
for _, hashString := range c.HashList {
ih, err := bittorrent.NewInfoHash(hashString)
ih, err := bittorrent.NewInfoHashString(hashString)
if err != nil {
return nil, fmt.Errorf("whitelist : %s : %w", hashString, err)
}
@@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er
}
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent")
var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi")
type hook struct {
hashContainer container.Container
@@ -84,7 +84,7 @@ func TestHandleAnnounce(t *testing.T) {
req := &bittorrent.AnnounceRequest{}
resp := &bittorrent.AnnounceResponse{}
hashinfo, err := bittorrent.NewInfoHash(tt.ih)
hashinfo, err := bittorrent.NewInfoHashString(tt.ih)
require.Nil(t, err)
req.InfoHash = hashinfo
+21
View File
@@ -0,0 +1,21 @@
// Package str2bytes provides fast, but unsafe functions to convert string to []byte
// or vice versa.
package str2bytes
import "unsafe"
// StringToBytes converts string to slice of bytes
// without allocation. Note, that returned slice
// must NOT be modified, since strings in Go are
// immutable.
// See unsafe.Slice.
func StringToBytes(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}
// BytesToString converts slice of bytes to string
// without allocation.
// See unsafe.String
func BytesToString(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}
+12 -12
View File
@@ -114,19 +114,19 @@ func (s *store) delPeer(ctx context.Context, infoHashKey, peerID string) error {
}
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer))
}
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer))
}
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer))
}
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer))
}
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) {
@@ -134,18 +134,18 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
Stringer("infoHash", ih).
Object("peer", peer).
Msg("graduate leecher")
infoHash, peerID := ih.RawString(), peer.RawString()
infoHash, peerID := ih.RawString(), r.PackPeer(peer)
ihSeederKey := r.InfoHashKey(infoHash, true, peer.Addr().Is6())
ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6())
var moved bool
if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil {
if moved {
err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL))
} else {
err = s.addPeer(ctx, ihSeederKey, peerID)
if !moved {
err = s.SAdd(ctx, ihSeederKey, peerID).Err()
}
if err == nil {
err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err()
if err != nil {
if err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)); err == nil {
err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err()
}
}
}
return err
@@ -160,7 +160,7 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe
Bool("v6", v6).
Msg("announce peers")
return s.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
return s.GetPeers(ctx, ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd {
return s.SRandMemberN(ctx, infoHashKey, int64(maxCount))
})
}
+8 -8
View File
@@ -177,14 +177,15 @@ func (p *peers) len() int {
return len(p.m)
}
func (p *peers) keys(fn func(k bittorrent.Peer) bool) {
func (p *peers) keys(fn func(k bittorrent.Peer) bool) bool {
p.RLock()
defer p.RUnlock()
for k := range p.m {
if !fn(k) {
break
return false
}
}
p.RUnlock()
return true
}
func (p *peers) forEach(fn func(k bittorrent.Peer, v int64) bool) {
@@ -268,7 +269,7 @@ func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 {
// There are twice the amount of shards specified by the user, the first
// half is dedicated to IPv4 swarms and the second half is dedicated to
// IPv6 swarms.
idx := binary.BigEndian.Uint32([]byte(infoHash[:4])) % (uint32(len(ps.shards)) / 2)
idx := binary.BigEndian.Uint32(infoHash.Bytes()[:4]) % (uint32(len(ps.shards)) / 2)
if v6 {
idx += uint32(len(ps.shards) / 2)
}
@@ -417,8 +418,7 @@ func (ps *peerStore) AnnouncePeers(_ context.Context, ih bittorrent.InfoHash, fo
if forSeeder {
sw.leechers.keys(rangeFn)
} else {
sw.seeders.keys(rangeFn)
if numWant > 0 {
if sw.seeders.keys(rangeFn) {
sw.leechers.keys(rangeFn)
}
}
@@ -446,8 +446,8 @@ func (ps *peerStore) ScrapeSwarm(_ context.Context, ih bittorrent.InfoHash) (lee
Stringer("infoHash", ih).
Msg("scrape swarm")
leechers, seeders = ps.countPeers(ih, true)
l, s := ps.countPeers(ih, false)
leechers, seeders = ps.countPeers(ih, false)
l, s := ps.countPeers(ih, true)
leechers, seeders = leechers+l, seeders+s
return
+21 -20
View File
@@ -359,15 +359,15 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) {
}()
}
func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
func (s *store) putPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) {
logger.Trace().
Stringer("infoHash", ih).
Hex("infoHash", ih).
Object("peer", peer).
Bool("seeder", seeder).
Msg("put peer")
_, err = s.Exec(ctx, s.Peer.AddQuery, pgx.NamedArgs{
pInfoHash: []byte(ih),
pPeerID: peer.ID[:],
pInfoHash: ih,
pPeerID: peer.ID.Bytes(),
pAddress: net.IP(peer.Addr().AsSlice()),
pPort: peer.Port(),
pSeeder: seeder,
@@ -377,14 +377,14 @@ func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor
return
}
func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
func (s *store) delPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) {
logger.Trace().
Stringer("infoHash", ih).
Hex("infoHash", ih).
Object("peer", peer).
Msg("del peer")
_, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{
pInfoHash: []byte(ih),
pPeerID: peer.ID[:],
pInfoHash: ih,
pPeerID: peer.ID.Bytes(),
pAddress: net.IP(peer.Addr().AsSlice()),
pPort: peer.Port(),
pSeeder: seeder,
@@ -393,19 +393,19 @@ func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor
}
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ctx, ih, peer, true)
return s.putPeer(ctx, ih.Bytes(), peer, true)
}
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, ih, peer, true)
return s.delPeer(ctx, ih.Bytes(), peer, true)
}
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ctx, ih, peer, false)
return s.putPeer(ctx, ih.Bytes(), peer, false)
}
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, ih, peer, false)
return s.delPeer(ctx, ih.Bytes(), peer, false)
}
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
@@ -414,10 +414,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
Object("peer", peer).
Msg("graduate leecher")
var batch pgx.Batch
ihb := []byte(ih)
ihb := ih.Bytes()
batch.Queue(s.Peer.GraduateQuery, pgx.NamedArgs{
pInfoHash: ihb,
pPeerID: peer.ID[:],
pPeerID: peer.ID.Bytes(),
pAddress: net.IP(peer.Addr().AsSlice()),
pPort: peer.Port(),
})
@@ -425,10 +425,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
return s.txBatch(ctx, &batch)
}
func (s *store) getPeers(ctx context.Context, ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
func (s *store) getPeers(ctx context.Context, ih []byte, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
var rows pgx.Rows
if rows, err = s.Query(ctx, s.Announce.Query, pgx.NamedArgs{
pInfoHash: []byte(ih),
pInfoHash: ih,
pSeeder: seeders,
pV6: isV6,
pCount: maxCount,
@@ -503,12 +503,13 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe
Int("numWant", numWant).
Bool("v6", v6).
Msg("announce peers")
ihb := ih.Bytes()
if forSeeder {
peers, err = s.getPeers(ctx, ih, false, numWant, v6)
peers, err = s.getPeers(ctx, ihb, false, numWant, v6)
} else {
if peers, err = s.getPeers(ctx, ih, true, numWant, v6); err == nil {
if peers, err = s.getPeers(ctx, ihb, true, numWant, v6); err == nil {
var addPeers []bittorrent.Peer
addPeers, err = s.getPeers(ctx, ih, false, numWant-len(peers), v6)
addPeers, err = s.getPeers(ctx, ihb, false, numWant-len(peers), v6)
peers = append(peers, addPeers...)
}
}
@@ -568,7 +569,7 @@ func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leeche
logger.Trace().
Stringer("infoHash", ih).
Msg("scrape swarm")
ihb := []byte(ih)
ihb := ih.Bytes()
if seeders, leechers, err = s.countPeers(ctx, ihb); err != nil {
return
}
+55 -12
View File
@@ -24,13 +24,18 @@ package redis
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
"github.com/sot-tech/mochi/pkg/str2bytes"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/conf"
@@ -389,20 +394,30 @@ func (ps *store) delPeer(ctx context.Context, infoHashKey, peerCountKey, peerID
return err
}
// PackPeer generates concatenation of PeerID, net port and IP-address
func PackPeer(p bittorrent.Peer) string {
ip := p.Addr()
b := make([]byte, bittorrent.PeerIDLen+2+(ip.BitLen()/8))
copy(b[:bittorrent.PeerIDLen], p.ID.Bytes())
binary.BigEndian.PutUint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2], p.Port())
copy(b[bittorrent.PeerIDLen+2:], ip.AsSlice())
return str2bytes.BytesToString(b)
}
func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer))
}
func (ps *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer))
}
func (ps *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer))
}
func (ps *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer))
}
func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
@@ -411,7 +426,7 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe
Object("peer", peer).
Msg("graduate leecher")
infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6()
infoHash, peerID, isV6 := ih.RawString(), PackPeer(peer), peer.Addr().Is6()
ihSeederKey, ihLeecherKey := InfoHashKey(infoHash, true, isV6), InfoHashKey(infoHash, false, isV6)
return ps.tx(ctx, func(tx redis.Pipeliner) error {
@@ -438,12 +453,42 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe
})
}
// peerMinimumLen is the least allowed length of string serialized Peer
const peerMinimumLen = bittorrent.PeerIDLen + 2 + net.IPv4len
var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (InfoHash + Port + IPv4))", peerMinimumLen)
// UnpackPeer constructs Peer from serialized by Peer.PackPeer data: PeerID[20by]Port[2by]net.IP[4/16by]
func UnpackPeer(data string) (bittorrent.Peer, error) {
var peer bittorrent.Peer
if len(data) < peerMinimumLen {
return peer, errInvalidPeerDataSize
}
b := str2bytes.StringToBytes(data)
peerID, err := bittorrent.NewPeerID(b[:bittorrent.PeerIDLen])
if err == nil {
if addr, isOk := netip.AddrFromSlice(b[bittorrent.PeerIDLen+2:]); isOk {
peer = bittorrent.Peer{
ID: peerID,
AddrPort: netip.AddrPortFrom(
addr.Unmap(),
binary.BigEndian.Uint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2]),
),
}
} else {
err = bittorrent.ErrInvalidIP
}
}
return peer, err
}
func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers []bittorrent.Peer, err error) {
var peerIds []string
peerIds, err = peersResult.Result()
if err = NoResultErr(err); err == nil {
for _, peerID := range peerIds {
if p, err := bittorrent.NewPeer(peerID); err == nil {
if p, err := UnpackPeer(peerID); err == nil {
peers = append(peers, p)
} else {
logger.Error().Err(err).Str("peerID", peerID).Msg("unable to decode peer")
@@ -453,13 +498,13 @@ func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers [
return
}
type getPeersFn func(string, int) *redis.StringSliceCmd
type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd
// GetPeers retrieves peers for provided info hash by calling membersFn and
// converts result to bittorrent.Peer array.
// If forSeeder set to true - returns only leechers, if false -
// seeders and if maxCount not reached - leechers.
func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) {
func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) {
infoHash := ih.RawString()
infoHashKeys := make([]string, 1, 2)
@@ -473,7 +518,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount
for _, infoHashKey := range infoHashKeys {
var peers []bittorrent.Peer
peers, err = ps.parsePeersList(membersFn(infoHashKey, maxCount))
peers, err = ps.parsePeersList(membersFn(ctx, infoHashKey, maxCount))
maxCount -= len(peers)
out = append(out, peers...)
if err != nil || maxCount <= 0 {
@@ -501,9 +546,7 @@ func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forS
Bool("v6", v6).
Msg("announce peers")
return ps.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
return ps.HRandField(ctx, infoHashKey, maxCount)
})
return ps.GetPeers(ctx, ih, forSeeder, numWant, v6, ps.HRandField)
}
type getPeerCountFn func(context.Context, string) *redis.IntCmd
+2 -2
View File
@@ -17,8 +17,8 @@ import (
)
const (
ihCount = 1000
peersCount = 1000
ihCount = 100
peersCount = 100
)
type benchData struct {