mirror of
https://github.com/sot-tech/mochi.git
synced 2026-06-11 23:43:29 -07:00
@@ -13,9 +13,9 @@ import (
|
||||
//
|
||||
// See ParseURLData for specifics on parsing and limitations.
|
||||
type Params interface {
|
||||
// String returns a string parsed from a query. Every key can be
|
||||
// GetString returns a string parsed from a query. Every key can be
|
||||
// returned as a string because they are encoded in the URL as strings.
|
||||
String(key string) (string, bool)
|
||||
GetString(key string) (string, bool)
|
||||
|
||||
zerolog.LogObjectMarshaler
|
||||
}
|
||||
|
||||
+41
-96
@@ -6,14 +6,12 @@ package bittorrent
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// PeerIDLen is length of peer id field in bytes
|
||||
@@ -25,26 +23,31 @@ type PeerID [PeerIDLen]byte
|
||||
// ErrInvalidPeerIDSize holds error about invalid PeerID size
|
||||
var ErrInvalidPeerIDSize = fmt.Errorf("peer ID must be %d bytes", PeerIDLen)
|
||||
|
||||
var zeroPeerID PeerID
|
||||
|
||||
// NewPeerID creates a PeerID from a byte slice.
|
||||
//
|
||||
// It panics if b is not 20 bytes long.
|
||||
func NewPeerID(b []byte) (PeerID, error) {
|
||||
var p PeerID
|
||||
if len(b) != PeerIDLen {
|
||||
return p, ErrInvalidPeerIDSize
|
||||
return zeroPeerID, ErrInvalidPeerIDSize
|
||||
}
|
||||
copy(p[:], b)
|
||||
return p, nil
|
||||
return PeerID(b), nil
|
||||
}
|
||||
|
||||
// Bytes returns slice of bytes represents this PeerID
|
||||
func (p PeerID) Bytes() []byte {
|
||||
return p[:]
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer, returning the base16 encoded PeerID.
|
||||
func (p PeerID) String() string {
|
||||
return hex.EncodeToString(p[:])
|
||||
return hex.EncodeToString(p.Bytes())
|
||||
}
|
||||
|
||||
// RawString returns a 20-byte string of the raw bytes of the ID.
|
||||
func (p PeerID) RawString() string {
|
||||
return string(p[:])
|
||||
return str2bytes.BytesToString(p.Bytes())
|
||||
}
|
||||
|
||||
// InfoHash represents an infohash.
|
||||
@@ -55,16 +58,33 @@ const (
|
||||
InfoHashV1Len = sha1.Size
|
||||
// InfoHashV2Len ... sha256.Size
|
||||
InfoHashV2Len = sha256.Size
|
||||
// NoneInfoHash dummy invalid InfoHash
|
||||
NoneInfoHash InfoHash = ""
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidHashType holds error about invalid InfoHash input type
|
||||
ErrInvalidHashType = errors.New("info hash must be provided as byte slice or raw/hex string")
|
||||
// ErrInvalidHashSize holds error about invalid InfoHash size
|
||||
ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes", InfoHashV1Len, InfoHashV2Len)
|
||||
)
|
||||
// ErrInvalidHashSize holds error about invalid InfoHash size
|
||||
var ErrInvalidHashSize = fmt.Errorf("info hash must be either %d (for torrent V1) or %d (V2) bytes or same sizes x2 (if HEX encoded)", InfoHashV1Len, InfoHashV2Len)
|
||||
|
||||
// NewInfoHash creates an InfoHash from raw/hex byte slice.
|
||||
func NewInfoHash(data []byte) (InfoHash, error) {
|
||||
var ih InfoHash
|
||||
switch l := len(data); l {
|
||||
case InfoHashV1Len, InfoHashV2Len:
|
||||
ih = InfoHash(data)
|
||||
case InfoHashV1Len * 2, InfoHashV2Len * 2:
|
||||
bb := make([]byte, l/2)
|
||||
if _, err := hex.Decode(bb, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ih = InfoHash(str2bytes.BytesToString(bb))
|
||||
default:
|
||||
return "", ErrInvalidHashSize
|
||||
}
|
||||
return ih, nil
|
||||
}
|
||||
|
||||
// NewInfoHashString creates an InfoHash from raw/hex string.
|
||||
func NewInfoHashString(data string) (InfoHash, error) {
|
||||
return NewInfoHash(str2bytes.StringToBytes(data))
|
||||
}
|
||||
|
||||
// TruncateV1 returns truncated to 20-bytes length array of the corresponding InfoHash.
|
||||
// If InfoHash is V2 (32 bytes), it will be truncated to 20 bytes
|
||||
@@ -76,48 +96,14 @@ func (i InfoHash) TruncateV1() InfoHash {
|
||||
return i
|
||||
}
|
||||
|
||||
// NewInfoHash creates an InfoHash from a byte slice or raw/hex string.
|
||||
func NewInfoHash(data any) (InfoHash, error) {
|
||||
if data == nil {
|
||||
return NoneInfoHash, ErrInvalidHashType
|
||||
}
|
||||
var ba []byte
|
||||
switch t := data.(type) {
|
||||
case [InfoHashV1Len]byte:
|
||||
ba = t[:]
|
||||
case [InfoHashV2Len]byte:
|
||||
ba = t[:]
|
||||
case []byte:
|
||||
l := len(t)
|
||||
if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 {
|
||||
ba = make([]byte, l/2)
|
||||
if _, err := hex.Decode(ba, t); err != nil {
|
||||
return NoneInfoHash, err
|
||||
}
|
||||
} else {
|
||||
ba = t
|
||||
}
|
||||
case string:
|
||||
l := len(t)
|
||||
if l == InfoHashV1Len*2 || l == InfoHashV2Len*2 {
|
||||
var err error
|
||||
if ba, err = hex.DecodeString(t); err != nil {
|
||||
return NoneInfoHash, err
|
||||
}
|
||||
} else {
|
||||
ba = []byte(t)
|
||||
}
|
||||
}
|
||||
l := len(ba)
|
||||
if l != InfoHashV1Len && l != InfoHashV2Len {
|
||||
return NoneInfoHash, ErrInvalidHashSize
|
||||
}
|
||||
return InfoHash(ba), nil
|
||||
// Bytes returns slice of bytes represents this InfoHash
|
||||
func (i InfoHash) Bytes() []byte {
|
||||
return str2bytes.StringToBytes(string(i))
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer, returning the base16 encoded InfoHash.
|
||||
func (i InfoHash) String() string {
|
||||
return hex.EncodeToString([]byte(i))
|
||||
return hex.EncodeToString(i.Bytes())
|
||||
}
|
||||
|
||||
// RawString returns a string of the raw bytes of the InfoHash.
|
||||
@@ -132,47 +118,6 @@ type Peer struct {
|
||||
netip.AddrPort
|
||||
}
|
||||
|
||||
// PeerMinimumLen is the least allowed length of string serialized Peer
|
||||
const PeerMinimumLen = PeerIDLen + 2 + net.IPv4len
|
||||
|
||||
// ErrInvalidPeerDataSize holds error about invalid Peer data size
|
||||
var ErrInvalidPeerDataSize = fmt.Errorf("invalid peer data it must be at least %d bytes (InfoHash + Port + IPv4)", PeerMinimumLen)
|
||||
|
||||
// NewPeer constructs Peer from serialized by Peer.RawString data: PeerID[20by]Port[2by]net.IP[4/16by]
|
||||
func NewPeer(data string) (Peer, error) {
|
||||
var peer Peer
|
||||
if len(data) < PeerMinimumLen {
|
||||
return peer, ErrInvalidPeerDataSize
|
||||
}
|
||||
b := []byte(data)
|
||||
peerID, err := NewPeerID(b[:PeerIDLen])
|
||||
if err == nil {
|
||||
if addr, isOk := netip.AddrFromSlice(b[PeerIDLen+2:]); isOk {
|
||||
peer = Peer{
|
||||
ID: peerID,
|
||||
AddrPort: netip.AddrPortFrom(
|
||||
addr.Unmap(),
|
||||
binary.BigEndian.Uint16(b[PeerIDLen:PeerIDLen+2]),
|
||||
),
|
||||
}
|
||||
} else {
|
||||
err = ErrInvalidIP
|
||||
}
|
||||
}
|
||||
|
||||
return peer, err
|
||||
}
|
||||
|
||||
// RawString generates concatenation of PeerID, net port and IP-address
|
||||
func (p Peer) RawString() string {
|
||||
ip := p.Addr()
|
||||
b := make([]byte, PeerIDLen+2+(ip.BitLen()/8))
|
||||
copy(b[:PeerIDLen], p.ID[:])
|
||||
binary.BigEndian.PutUint16(b[PeerIDLen:PeerIDLen+2], p.Port())
|
||||
copy(b[PeerIDLen+2:], ip.AsSlice())
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Addr returns unmapped peer's IP address
|
||||
func (p Peer) Addr() netip.Addr {
|
||||
return p.AddrPort.Addr().Unmap()
|
||||
|
||||
+5
-19
@@ -238,24 +238,10 @@ func (s Scrape) MarshalZerologObject(e *zerolog.Event) {
|
||||
// Scrapes wrapper of array of Scrape-s
|
||||
type Scrapes []Scrape
|
||||
|
||||
func (s *Scrapes) Len() int {
|
||||
return len(*s)
|
||||
}
|
||||
|
||||
func (s *Scrapes) Less(i, j int) bool {
|
||||
return (*s)[i].InfoHash < (*s)[j].InfoHash
|
||||
}
|
||||
|
||||
func (s *Scrapes) Swap(i, j int) {
|
||||
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
|
||||
}
|
||||
|
||||
// MarshalZerologArray writes array elements to zerolog event
|
||||
func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
if s != nil {
|
||||
for _, scrape := range *s {
|
||||
a.Object(scrape)
|
||||
}
|
||||
func (s Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
for _, scrape := range s {
|
||||
a.Object(scrape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,10 +250,10 @@ func (s *Scrapes) MarshalZerologArray(a *zerolog.Array) {
|
||||
// The Scrapes must be in the same order as the InfoHashes in the corresponding
|
||||
// ScrapeRequest.
|
||||
type ScrapeResponse struct {
|
||||
Files Scrapes
|
||||
Data Scrapes
|
||||
}
|
||||
|
||||
// MarshalZerologObject writes fields into zerolog event
|
||||
func (sr ScrapeResponse) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Array("scrapes", &sr.Files)
|
||||
e.Array("scrapes", sr.Data)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cr "crypto/rand"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/frontend"
|
||||
hf "github.com/sot-tech/mochi/frontend/http"
|
||||
l "github.com/sot-tech/mochi/pkg/log"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
const (
|
||||
connectUDPPacketSize = 16
|
||||
announceUDPPacketSize = 98
|
||||
announceNumWant = 10
|
||||
|
||||
timeout = time.Second * 2
|
||||
)
|
||||
|
||||
var (
|
||||
udpConnectHeader = []byte{0x0, 0x0, 0x4, 0x17, 0x27, 0x10, 0x19, 0x80}
|
||||
errUDPSendTruncated = errors.New("data not fully sent")
|
||||
errUDPRecvTruncated = errors.New("data not fully received")
|
||||
errUDPUnexpectedAction = errors.New("unexpected action")
|
||||
errTxIDMissmatch = errors.New("transaction ID missmatch")
|
||||
hashes = make([][]byte, 100)
|
||||
peers = make([][]byte, 100)
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = l.ConfigureLogger("", "error", false, false)
|
||||
for i := range hashes {
|
||||
bb := make([]byte, bittorrent.InfoHashV1Len)
|
||||
if _, err := cr.Read(bb); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
hashes[i] = bb
|
||||
}
|
||||
|
||||
for i := range peers {
|
||||
bb := make([]byte, bittorrent.PeerIDLen)
|
||||
if _, err := cr.Read(bb); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
peers[i] = bb
|
||||
}
|
||||
}
|
||||
|
||||
func buildUDPConnReq() []byte {
|
||||
req := make([]byte, connectUDPPacketSize)
|
||||
copy(req, udpConnectHeader)
|
||||
|
||||
// TxID
|
||||
binary.BigEndian.PutUint32(req[12:16], rand.Uint32())
|
||||
return req
|
||||
}
|
||||
|
||||
func sendUDPConnReq(addr string) ([]byte, []byte, error) {
|
||||
req := buildUDPConnReq()
|
||||
|
||||
conn, err := net.Dial("udp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
n, err := conn.Write(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if n != len(req) {
|
||||
return nil, nil, errUDPSendTruncated
|
||||
}
|
||||
|
||||
resp := make([]byte, 16)
|
||||
n, err = conn.Read(resp)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if n != len(resp) {
|
||||
return nil, nil, errUDPRecvTruncated
|
||||
}
|
||||
|
||||
action := binary.BigEndian.Uint32(resp[:4])
|
||||
if action != 0 {
|
||||
return nil, nil, errUDPUnexpectedAction
|
||||
}
|
||||
|
||||
if !bytes.Equal(resp[4:8], req[12:16]) {
|
||||
return nil, nil, errTxIDMissmatch
|
||||
}
|
||||
|
||||
// TxID, ConnectionID
|
||||
return resp[4:8], resp[8:16], nil
|
||||
}
|
||||
|
||||
func buildAnnounceUDPReq(txID, connID []byte) []byte {
|
||||
req := make([]byte, announceUDPPacketSize)
|
||||
|
||||
// Connection ID
|
||||
copy(req[:8], connID)
|
||||
|
||||
// Action
|
||||
req[11] = 1
|
||||
|
||||
// TxID
|
||||
copy(req[12:16], txID)
|
||||
|
||||
// InfoHash
|
||||
copy(req[16:36], hashes[rand.Intn(len(hashes))])
|
||||
|
||||
// PeerID
|
||||
copy(req[36:56], peers[rand.Intn(len(peers))])
|
||||
|
||||
var down, left uint64
|
||||
if rand.Intn(2) == 0 {
|
||||
down, left = 1, 0
|
||||
} else {
|
||||
down, left = 0, 1
|
||||
}
|
||||
// Downloaded
|
||||
binary.BigEndian.PutUint64(req[56:64], down)
|
||||
// Left
|
||||
binary.BigEndian.PutUint64(req[64:72], left)
|
||||
|
||||
// Event
|
||||
req[83] = 1
|
||||
|
||||
// Numwant
|
||||
req[92], req[95] = byte(announceNumWant>>24), byte(announceNumWant>>16)
|
||||
|
||||
// Port
|
||||
p := rand.Intn(math.MaxInt16) + 1
|
||||
req[96], req[97] = byte(p>>8), byte(p)
|
||||
return req
|
||||
}
|
||||
|
||||
func BenchmarkServerUDPAnnounce(b *testing.B) {
|
||||
var s Server
|
||||
if err := s.Run(QuickConfig); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
addr := "127.0.0.1" + frontend.DefaultListenAddress
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
txID, connID, err := sendUDPConnReq(addr)
|
||||
if err != nil {
|
||||
b.Log(err)
|
||||
return
|
||||
}
|
||||
req := buildAnnounceUDPReq(txID, connID)
|
||||
|
||||
conn, err := net.Dial("udp", addr)
|
||||
if err != nil {
|
||||
b.Log(err)
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
n, err := conn.Write(req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
b.Log(err)
|
||||
return
|
||||
}
|
||||
if n != announceUDPPacketSize {
|
||||
_ = conn.Close()
|
||||
b.Logf("packet not fully sent, %d bytes instead of %d", n, announceUDPPacketSize)
|
||||
return
|
||||
}
|
||||
resp := make([]byte, 1024)
|
||||
n, err = conn.Read(resp)
|
||||
_ = conn.Close()
|
||||
if err != nil {
|
||||
b.Log(err)
|
||||
return
|
||||
}
|
||||
if n < 20 {
|
||||
b.Logf("packet not fully received, only %d bytes", n)
|
||||
return
|
||||
}
|
||||
action := binary.BigEndian.Uint32(resp[:4])
|
||||
if action != 1 {
|
||||
if action == 3 {
|
||||
errVal := string(resp[8:n])
|
||||
b.Logf("tracker error: %s", errVal)
|
||||
} else {
|
||||
b.Logf("unexpected action: %d", action)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(resp[4:8], req[12:16]) {
|
||||
b.Log("transaction ID missmatch")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func sendHTTPReq(u string) (err error) {
|
||||
var r *http.Response
|
||||
// nolint:gosec
|
||||
if r, err = http.Get(u); err == nil {
|
||||
defer r.Body.Close()
|
||||
if r.StatusCode < 400 {
|
||||
var out []byte
|
||||
if out, err = io.ReadAll(r.Body); err == nil {
|
||||
sout := string(out)
|
||||
if strings.Contains(sout, "failure reason") {
|
||||
return errors.New(sout)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors.New(r.Status)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func BenchmarkServerHTTPAnnounce(b *testing.B) {
|
||||
var s Server
|
||||
if err := s.Run(QuickConfig); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
reqs := make([]string, len(hashes)*len(peers))
|
||||
addr := "127.0.0.1" + frontend.DefaultListenAddress
|
||||
for i := range reqs {
|
||||
var down, left string
|
||||
if rand.Intn(2) == 0 {
|
||||
down, left = "1", "0"
|
||||
} else {
|
||||
down, left = "0", "1"
|
||||
}
|
||||
u := url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr,
|
||||
Path: hf.DefaultAnnounceRoute,
|
||||
RawQuery: url.Values{
|
||||
"event": []string{bittorrent.StartedStr},
|
||||
"compact": []string{"1"},
|
||||
"left": []string{left},
|
||||
"downloaded": []string{down},
|
||||
"uploaded": []string{"0"},
|
||||
"numwant": []string{"1"},
|
||||
"port": []string{strconv.FormatInt(int64(rand.Intn(math.MaxInt16)+1), 10)},
|
||||
"info_hash": []string{str2bytes.BytesToString(hashes[rand.Intn(len(hashes))])},
|
||||
"peer_id": []string{str2bytes.BytesToString(peers[rand.Intn(len(peers))])},
|
||||
}.Encode(),
|
||||
}
|
||||
reqs[i] = u.String()
|
||||
}
|
||||
|
||||
var cnt atomic.Uint32
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
if err := sendHTTPReq(reqs[int(cnt.Add(1))%len(reqs)]); err != nil {
|
||||
b.Log(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Vendored
+3
-3
@@ -190,10 +190,10 @@ prehooks:
|
||||
#
|
||||
# - name: client approval
|
||||
# config:
|
||||
# whitelist:
|
||||
# client_id_list:
|
||||
# - "OP1011"
|
||||
# blacklist:
|
||||
# - "OP1012"
|
||||
# true - whitelist mode, false - blacklist
|
||||
# invert: true
|
||||
#
|
||||
# - name: interval variation
|
||||
# config:
|
||||
|
||||
@@ -51,11 +51,15 @@ type Config struct {
|
||||
}
|
||||
|
||||
const (
|
||||
defaultReadTimeout = 2 * time.Second
|
||||
defaultWriteTimeout = 2 * time.Second
|
||||
defaultIdleTimeout = 30 * time.Second
|
||||
defaultAnnounceRoute = "/announce"
|
||||
defaultScrapeRoute = "/scrape"
|
||||
defaultReadTimeout = 2 * time.Second
|
||||
defaultWriteTimeout = 2 * time.Second
|
||||
defaultIdleTimeout = 30 * time.Second
|
||||
// DefaultAnnounceRoute is the default url path to listen announce
|
||||
// requests if nothing else provided
|
||||
DefaultAnnounceRoute = "/announce"
|
||||
// DefaultScrapeRoute is the default url path to listen scrape
|
||||
// requests if nothing else provided
|
||||
DefaultScrapeRoute = "/scrape"
|
||||
)
|
||||
|
||||
// Validate sanity checks values set in a config and returns a new config with
|
||||
@@ -98,7 +102,7 @@ func (cfg Config) Validate() (validCfg Config, err error) {
|
||||
}
|
||||
}
|
||||
if len(cfg.AnnounceRoutes) == 0 {
|
||||
validCfg.AnnounceRoutes = []string{defaultAnnounceRoute}
|
||||
validCfg.AnnounceRoutes = []string{DefaultAnnounceRoute}
|
||||
logger.Warn().
|
||||
Str("name", "AnnounceRoutes").
|
||||
Strs("provided", cfg.AnnounceRoutes).
|
||||
@@ -106,7 +110,7 @@ func (cfg Config) Validate() (validCfg Config, err error) {
|
||||
Msg("falling back to default configuration")
|
||||
}
|
||||
if len(cfg.ScrapeRoutes) == 0 {
|
||||
validCfg.ScrapeRoutes = []string{defaultScrapeRoute}
|
||||
validCfg.ScrapeRoutes = []string{DefaultScrapeRoute}
|
||||
logger.Warn().
|
||||
Str("name", "ScrapeRoutes").
|
||||
Strs("provided", cfg.ScrapeRoutes).
|
||||
@@ -211,7 +215,10 @@ func runServer(s *fasthttp.Server, cfg *Config) {
|
||||
err = s.ServeTLS(ln, "", "")
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
defer ln.Close()
|
||||
if err == nil {
|
||||
logger.Info().Msg("server stopped")
|
||||
} else if !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Fatal().Err(err).Msg("server failed")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,6 @@ import (
|
||||
"github.com/sot-tech/mochi/pkg/log"
|
||||
)
|
||||
|
||||
const iterations = 10000
|
||||
|
||||
var (
|
||||
addr = fmt.Sprintf("127.0.0.1:%d", rand.Int63n(10000)+16384)
|
||||
hashes = make([]string, 10)
|
||||
@@ -87,7 +85,7 @@ func BenchmarkPing(b *testing.B) {
|
||||
Path: "ping",
|
||||
}
|
||||
us := u.String()
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := runGet(us, false); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
@@ -95,11 +93,11 @@ func BenchmarkPing(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkAnnounce(b *testing.B) {
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := 0; i < b.N; i++ {
|
||||
u := url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr,
|
||||
Path: defaultAnnounceRoute,
|
||||
Path: DefaultAnnounceRoute,
|
||||
RawQuery: url.Values{
|
||||
"event": []string{bittorrent.StartedStr},
|
||||
"compact": []string{"1"},
|
||||
@@ -119,11 +117,11 @@ func BenchmarkAnnounce(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkScrape(b *testing.B) {
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := 0; i < b.N; i++ {
|
||||
u := url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr,
|
||||
Path: defaultScrapeRoute,
|
||||
Path: DefaultScrapeRoute,
|
||||
RawQuery: url.Values{"info_hash": hashes[:len(hashes)/2]}.Encode(),
|
||||
}
|
||||
if err := runGet(u.String(), true); err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
@@ -12,11 +13,11 @@ type queryParams struct {
|
||||
*fasthttp.Args
|
||||
}
|
||||
|
||||
// String returns a string parsed from a query. Every key can be returned as a
|
||||
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||
// string because they are encoded in the URL as strings.
|
||||
func (qp queryParams) String(key string) (string, bool) {
|
||||
func (qp queryParams) GetString(key string) (string, bool) {
|
||||
v := qp.Peek(key)
|
||||
return string(v), v != nil
|
||||
return str2bytes.BytesToString(v), v != nil
|
||||
}
|
||||
|
||||
// InfoHashes returns a list of requested infohashes.
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/frontend"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// ParseOptions is the configuration used to parse an Announce Request.
|
||||
@@ -40,7 +41,7 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
|
||||
// Attempt to parse the event from the request.
|
||||
var eventStr string
|
||||
var err error
|
||||
eventStr, request.EventProvided = qp.String("event")
|
||||
eventStr, request.EventProvided = qp.GetString("event")
|
||||
if request.EventProvided {
|
||||
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
|
||||
return nil, err
|
||||
@@ -57,7 +58,6 @@ func parseAnnounce(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.Annou
|
||||
if len(infoHashes) > 1 {
|
||||
return nil, errMultipleInfoHashes
|
||||
}
|
||||
// FIXME: make sure that we have a copy of InfoHash
|
||||
request.InfoHash = infoHashes[0]
|
||||
|
||||
// Parse the PeerID from the request.
|
||||
@@ -137,7 +137,7 @@ func parseScrape(r *fasthttp.RequestCtx, opts ParseOptions) (*bittorrent.ScrapeR
|
||||
func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (addresses bittorrent.RequestAddresses) {
|
||||
if opts.AllowIPSpoofing {
|
||||
for _, f := range []string{"ip", "ipv4", "ipv6"} {
|
||||
if ipStr, ok := p.String(f); ok {
|
||||
if ipStr, ok := p.GetString(f); ok {
|
||||
addresses.Add(parseRequestAddress(ipStr, true))
|
||||
}
|
||||
}
|
||||
@@ -147,7 +147,7 @@ func requestedIPs(r *fasthttp.RequestCtx, p *queryParams, opts ParseOptions) (ad
|
||||
for _, ipStr := range ipValues {
|
||||
for _, ipStr := range bytes.Split(ipStr, []byte{','}) {
|
||||
if ipStr = bytes.TrimSpace(ipStr); len(ipStr) > 0 {
|
||||
addresses.Add(parseRequestAddress(string(ipStr), false))
|
||||
addresses.Add(parseRequestAddress(str2bytes.BytesToString(ipStr), false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+24
-15
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
var respBufferPool = bytepool.NewBufferPool()
|
||||
|
||||
func writeErrorResponse(w io.StringWriter, err error) {
|
||||
message := "internal server error"
|
||||
message := "mochi internal error"
|
||||
var clientErr bittorrent.ClientError
|
||||
if errors.As(err, &clientErr) {
|
||||
message = clientErr.Error()
|
||||
@@ -96,28 +97,36 @@ func dictAddress(bb *bytes.Buffer, peer bittorrent.Peer, includePeerID bool) {
|
||||
bb.WriteString(addr)
|
||||
if includePeerID {
|
||||
bb.WriteString("7:peer id20:")
|
||||
bb.Write(peer.ID[:])
|
||||
bb.Write(peer.ID.Bytes())
|
||||
}
|
||||
bb.WriteString("4:porti")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(peer.Port())))
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
port := peer.Port()
|
||||
bb.Write([]byte{byte(port >> 8), byte(port), 'e', 'e'})
|
||||
}
|
||||
|
||||
func writeScrapeResponse(w io.Writer, resp *bittorrent.ScrapeResponse) {
|
||||
bb := respBufferPool.Get()
|
||||
defer respBufferPool.Put(bb)
|
||||
bb.WriteString("d5:filesd")
|
||||
for _, scrape := range resp.Files {
|
||||
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
||||
bb.WriteByte(':')
|
||||
bb.Write([]byte(scrape.InfoHash))
|
||||
bb.WriteString("d8:completei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
||||
bb.WriteString("e10:downloadedi")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
||||
bb.WriteString("e10:incompletei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
||||
bb.WriteString("ee")
|
||||
l := len(resp.Data)
|
||||
if l > 0 {
|
||||
if l > 1 {
|
||||
sort.Slice(resp.Data, func(i, j int) bool {
|
||||
return resp.Data[i].InfoHash < resp.Data[j].InfoHash
|
||||
})
|
||||
}
|
||||
for _, scrape := range resp.Data {
|
||||
bb.Write(fasthttp.AppendUint(nil, len(scrape.InfoHash)))
|
||||
bb.WriteByte(':')
|
||||
bb.Write([]byte(scrape.InfoHash))
|
||||
bb.WriteString("d8:completei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Complete)))
|
||||
bb.WriteString("e10:downloadedi")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Snatches)))
|
||||
bb.WriteString("e10:incompletei")
|
||||
bb.Write(fasthttp.AppendUint(nil, int(scrape.Incomplete)))
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
}
|
||||
}
|
||||
bb.Write([]byte{'e', 'e'})
|
||||
_, _ = bb.WriteTo(w)
|
||||
|
||||
+4
-2
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/libp2p/go-reuseport"
|
||||
)
|
||||
|
||||
const defaultListenAddress = ":6969"
|
||||
// DefaultListenAddress is the default net listen address
|
||||
// if nothing else provided
|
||||
const DefaultListenAddress = ":6969"
|
||||
|
||||
var errUnexpectedListenerType = errors.New("unexpected listener type")
|
||||
|
||||
@@ -26,7 +28,7 @@ type ListenOptions struct {
|
||||
func (lo ListenOptions) Validate(logger *log.Logger) (validOptions ListenOptions) {
|
||||
validOptions = lo
|
||||
if len(lo.Addr) == 0 {
|
||||
validOptions.Addr = defaultListenAddress
|
||||
validOptions.Addr = DefaultListenAddress
|
||||
logger.Warn().
|
||||
Str("name", "Addr").
|
||||
Str("provided", lo.Addr).
|
||||
|
||||
@@ -34,7 +34,7 @@ const (
|
||||
|
||||
var (
|
||||
logger = log.NewLogger("frontend/udp")
|
||||
allowedGeneratedPrivateKeyRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
||||
allowedGeneratedPrivateKeyRunes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -57,16 +57,16 @@ func (cfg Config) Validate() (validCfg Config) {
|
||||
validCfg.ListenOptions = cfg.ListenOptions.Validate(logger)
|
||||
|
||||
if cfg.Workers == 0 {
|
||||
cfg.Workers = 1
|
||||
validCfg.Workers = 1
|
||||
}
|
||||
if cfg.Workers > 1 && !cfg.ReusePort {
|
||||
cfg.ReusePort = true
|
||||
if validCfg.Workers > 1 && !validCfg.ReusePort {
|
||||
validCfg.ReusePort = true
|
||||
logger.Warn().Msg("forcibly enabling ReusePort because Workers > 1")
|
||||
}
|
||||
|
||||
// Generate a private key if one isn't provided by the user.
|
||||
if cfg.PrivateKey == "" {
|
||||
pkeyRunes := make([]rune, defaultKeyLen)
|
||||
pkeyRunes := make([]byte, defaultKeyLen)
|
||||
for i := range pkeyRunes {
|
||||
pkeyRunes[i] = allowedGeneratedPrivateKeyRunes[rand.Intn(len(allowedGeneratedPrivateKeyRunes))]
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
)
|
||||
|
||||
// ErrInvalidQueryEscape is returned when a query string contains invalid
|
||||
@@ -60,7 +61,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
key, value = key[:i], key[i+1:]
|
||||
}
|
||||
var k, v string
|
||||
k, err = url.QueryUnescape(string(key))
|
||||
k, err = url.QueryUnescape(str2bytes.BytesToString(key))
|
||||
if err != nil {
|
||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||
// But frontends record these errors to prometheus, which generates
|
||||
@@ -68,7 +69,7 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
// We log it here for debugging instead.
|
||||
return nil, ErrInvalidQueryEscape
|
||||
}
|
||||
v, err = url.QueryUnescape(string(value))
|
||||
v, err = url.QueryUnescape(str2bytes.BytesToString(value))
|
||||
if err != nil {
|
||||
// QueryUnescape returns an error like "invalid escape: '%x'".
|
||||
// But frontends record these errors to prometheus, which generates
|
||||
@@ -83,9 +84,9 @@ func parseQuery(query []byte) (q *queryParams, err error) {
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// String returns a string parsed from a query. Every key can be returned as a
|
||||
// GetString returns a string parsed from a query. Every key can be returned as a
|
||||
// string because they are encoded in the URL as strings.
|
||||
func (qp queryParams) String(key string) (string, bool) {
|
||||
func (qp queryParams) GetString(key string) (string, bool) {
|
||||
value, ok := qp.params[strings.ToLower(key)]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
@@ -71,7 +71,6 @@ func parseAnnounce(r Request, v6Action bool, opts frontend.ParseOptions) (*bitto
|
||||
|
||||
// XXX: pure V2 hashes will cause invalid parsing,
|
||||
// but BEP-52 says, that V2 hashes SHOULD be truncated
|
||||
// FIXME: make sure that we have a copy of InfoHash
|
||||
request.InfoHash, err = bittorrent.NewInfoHash(r.Packet[16:36])
|
||||
if err != nil {
|
||||
return nil, errInvalidInfoHash
|
||||
@@ -178,7 +177,6 @@ func parseScrape(r Request, opts frontend.ParseOptions) (*bittorrent.ScrapeReque
|
||||
var request *bittorrent.ScrapeRequest
|
||||
for len(r.Packet) >= bittorrent.InfoHashV1Len {
|
||||
var ih bittorrent.InfoHash
|
||||
// FIXME: make sure that we have a copy of InfoHash
|
||||
if ih, err = bittorrent.NewInfoHash(r.Packet[:bittorrent.InfoHashV1Len]); err == nil {
|
||||
infoHashes = append(infoHashes, ih)
|
||||
r.Packet = r.Packet[bittorrent.InfoHashV1Len:]
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestHandleOptionalParameters(t *testing.T) {
|
||||
t.Fatalf("expected values %v for %x", tt.values, tt.data)
|
||||
} else {
|
||||
for key, want := range tt.values {
|
||||
if got, ok := params.String(key); !ok {
|
||||
if got, ok := params.GetString(key); !ok {
|
||||
t.Fatalf("params missing entry %s for data %x", key, tt.data)
|
||||
} else if got != want {
|
||||
t.Fatalf("expected param %s=%s, but was %s for data %x", key, want, got, tt.data)
|
||||
|
||||
+12
-9
@@ -14,13 +14,17 @@ func writeErrorResponse(w io.Writer, txID []byte, err error) {
|
||||
buf := reqRespBufferPool.Get()
|
||||
defer reqRespBufferPool.Put(buf)
|
||||
writeHeader(buf, txID, errorActionID)
|
||||
message := "mochi internal error"
|
||||
var clientErr bittorrent.ClientError
|
||||
// If the client wasn't at fault, acknowledge it.
|
||||
if !errors.As(err, new(bittorrent.ClientError)) {
|
||||
buf.WriteString("internal error occurred: ")
|
||||
if errors.As(err, &clientErr) {
|
||||
message = clientErr.Error()
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("internal error")
|
||||
}
|
||||
buf.WriteString(err.Error())
|
||||
buf.WriteString(message)
|
||||
buf.WriteByte('\000')
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeAnnounceResponse encodes an announce response according to BEP 15.
|
||||
@@ -51,7 +55,7 @@ func writeAnnounceResponse(w io.Writer, txID []byte, resp *bittorrent.AnnounceRe
|
||||
_ = binary.Write(buf, binary.BigEndian, peer.Port())
|
||||
}
|
||||
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeScrapeResponse encodes a scrape response according to BEP 15.
|
||||
@@ -61,13 +65,12 @@ func writeScrapeResponse(w io.Writer, txID []byte, resp *bittorrent.ScrapeRespon
|
||||
|
||||
writeHeader(buf, txID, scrapeActionID)
|
||||
|
||||
for _, scrape := range resp.Files {
|
||||
for _, scrape := range resp.Data {
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Complete)
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Snatches)
|
||||
_ = binary.Write(buf, binary.BigEndian, scrape.Incomplete)
|
||||
}
|
||||
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeConnectionID encodes a new connection response according to BEP 15.
|
||||
@@ -77,7 +80,7 @@ func writeConnectionID(w io.Writer, txID, connID []byte) {
|
||||
|
||||
writeHeader(buf, txID, connectActionID)
|
||||
buf.Write(connID)
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
_, _ = buf.WriteTo(w)
|
||||
}
|
||||
|
||||
// writeHeader writes the action and transaction ID to the provided response
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
module github.com/sot-tech/mochi
|
||||
|
||||
go 1.19
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
code.cloudfoundry.org/go-diodes v0.0.0-20230317203753-49f1af6d2f1a
|
||||
github.com/MicahParks/keyfunc v1.9.0
|
||||
github.com/anacrolix/torrent v1.48.0
|
||||
github.com/anacrolix/torrent v1.49.0
|
||||
github.com/cespare/xxhash/v2 v2.2.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/jackc/pgx/v5 v5.3.1
|
||||
@@ -41,7 +41,7 @@ require (
|
||||
github.com/klauspost/compress v1.16.3 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.17 // indirect
|
||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
|
||||
@@ -40,8 +40,8 @@ github.com/anacrolix/stm v0.2.0/go.mod h1:zoVQRvSiGjGoTmbM0vSLIiaKjWtNPeTvXUSdJQ
|
||||
github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
|
||||
github.com/anacrolix/tagflag v1.0.0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
|
||||
github.com/anacrolix/tagflag v1.1.0/go.mod h1:Scxs9CV10NQatSmbyjqmqmeQNwGzlNe0CMUMIxqHIG8=
|
||||
github.com/anacrolix/torrent v1.48.0 h1:OQe1aQb8WnhDzpcI7r3yWoHzHWKyPbfhXGfO9Q/pvbY=
|
||||
github.com/anacrolix/torrent v1.48.0/go.mod h1:3UtkJ8BnxXDRwvk+eT+uwiZalfFJ8YzAhvxe4QRPSJI=
|
||||
github.com/anacrolix/torrent v1.49.0 h1:v/TAd8BKsZarYEYv7VkPNv8tY5zZCwQyxMMlKKbAF4I=
|
||||
github.com/anacrolix/torrent v1.49.0/go.mod h1:qT3yS5oQwDUHnBXy+zf3nozLPudG7SFNDL3Jl/zQwFw=
|
||||
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
|
||||
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
|
||||
@@ -176,8 +176,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
|
||||
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
// - https://github.com/webtorrent/bittorrent-peerid/blob/master/lib/utils.js
|
||||
package clientapproval
|
||||
|
||||
import "github.com/sot-tech/mochi/bittorrent"
|
||||
import (
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
)
|
||||
|
||||
// ClientID represents the part of a PeerID that identifies a Peer's client
|
||||
// software.
|
||||
@@ -14,14 +16,9 @@ type ClientID [6]byte
|
||||
func NewClientID(pid bittorrent.PeerID) ClientID {
|
||||
var cid ClientID
|
||||
if pid[0] == '-' {
|
||||
copy(cid[:], pid[1:7])
|
||||
cid = ClientID(pid[1:7])
|
||||
} else {
|
||||
copy(cid[:], pid[:6])
|
||||
cid = ClientID(pid[:6])
|
||||
}
|
||||
|
||||
return cid
|
||||
}
|
||||
|
||||
func (cid ClientID) String() string {
|
||||
return string(cid[:])
|
||||
}
|
||||
|
||||
@@ -20,23 +20,21 @@ func init() {
|
||||
middleware.RegisterBuilder(Name, build)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
||||
ErrClientUnapproved = bittorrent.ClientError("unapproved client")
|
||||
|
||||
errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
|
||||
)
|
||||
// ErrClientUnapproved is the error returned when a client's PeerID is invalid.
|
||||
var ErrClientUnapproved = bittorrent.ClientError("client not allowed by mochi")
|
||||
|
||||
// Config represents all the values required by this middleware to validate
|
||||
// peers based on their BitTorrent client ID.
|
||||
type Config struct {
|
||||
Whitelist []string
|
||||
Blacklist []string
|
||||
// Static list of client IDs.
|
||||
ClientIDList []string `cfg:"client_id_list"`
|
||||
// If Invert set to true, all client IDs stored in ClientIDList should be blacklisted.
|
||||
Invert bool
|
||||
}
|
||||
|
||||
type hook struct {
|
||||
approved map[ClientID]struct{}
|
||||
unapproved map[ClientID]struct{}
|
||||
clientIDs map[ClientID]any
|
||||
invert bool
|
||||
}
|
||||
|
||||
func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
|
||||
@@ -47,53 +45,31 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error
|
||||
}
|
||||
|
||||
h := &hook{
|
||||
approved: make(map[ClientID]struct{}),
|
||||
unapproved: make(map[ClientID]struct{}),
|
||||
clientIDs: make(map[ClientID]any, len(cfg.ClientIDList)),
|
||||
invert: cfg.Invert,
|
||||
}
|
||||
|
||||
if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
|
||||
return nil, errBothListsProvided
|
||||
}
|
||||
|
||||
for _, cidString := range cfg.Whitelist {
|
||||
for _, cidString := range cfg.ClientIDList {
|
||||
cidBytes := []byte(cidString)
|
||||
if len(cidBytes) != 6 {
|
||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
||||
}
|
||||
var cid ClientID
|
||||
copy(cid[:], cidBytes)
|
||||
h.approved[cid] = struct{}{}
|
||||
}
|
||||
|
||||
for _, cidString := range cfg.Blacklist {
|
||||
cidBytes := []byte(cidString)
|
||||
if len(cidBytes) != 6 {
|
||||
return nil, errors.New("client ID " + cidString + " must be 6 bytes")
|
||||
}
|
||||
var cid ClientID
|
||||
copy(cid[:], cidBytes)
|
||||
h.unapproved[cid] = struct{}{}
|
||||
h.clientIDs[ClientID(cidBytes)] = true
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// HandleAnnounce checks if specified ClientID is approved or not.
|
||||
// If Config.Invert set to true and hash found in provided list, function will return ErrClientUnapproved,
|
||||
// that means that ClientID is blacklisted.
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
clientID := NewClientID(req.ID)
|
||||
|
||||
if len(h.approved) > 0 {
|
||||
if _, found := h.approved[clientID]; !found {
|
||||
return ctx, ErrClientUnapproved
|
||||
}
|
||||
var err error
|
||||
if _, contains := h.clientIDs[NewClientID(req.ID)]; contains == h.invert {
|
||||
err = ErrClientUnapproved
|
||||
}
|
||||
|
||||
if len(h.unapproved) > 0 {
|
||||
if _, found := h.unapproved[clientID]; found {
|
||||
return ctx, ErrClientUnapproved
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
|
||||
|
||||
@@ -19,7 +19,7 @@ var cases = []struct {
|
||||
// Client ID is whitelisted
|
||||
{
|
||||
Config{
|
||||
Whitelist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
},
|
||||
"01020304050607080900",
|
||||
true,
|
||||
@@ -27,7 +27,7 @@ var cases = []struct {
|
||||
// Client ID is not whitelisted
|
||||
{
|
||||
Config{
|
||||
Whitelist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
},
|
||||
"10203040506070809000",
|
||||
false,
|
||||
@@ -35,7 +35,8 @@ var cases = []struct {
|
||||
// Client ID is not blacklisted
|
||||
{
|
||||
Config{
|
||||
Blacklist: []string{"010203"},
|
||||
ClientIDList: []string{"010203"},
|
||||
Invert: true,
|
||||
},
|
||||
"00000000001234567890",
|
||||
true,
|
||||
@@ -43,7 +44,8 @@ var cases = []struct {
|
||||
// Client ID is blacklisted
|
||||
{
|
||||
Config{
|
||||
Blacklist: []string{"123456"},
|
||||
ClientIDList: []string{"123456"},
|
||||
Invert: true,
|
||||
},
|
||||
"12345678900000000000",
|
||||
false,
|
||||
@@ -53,7 +55,7 @@ var cases = []struct {
|
||||
func TestHandleAnnounce(t *testing.T) {
|
||||
for _, tt := range cases {
|
||||
t.Run(fmt.Sprintf("testing peerid %s", tt.peerID), func(t *testing.T) {
|
||||
c := conf.MapConfig{"whitelist": tt.cfg.Whitelist, "blacklist": tt.cfg.Blacklist}
|
||||
c := conf.MapConfig{"client_id_list": tt.cfg.ClientIDList, "invert": tt.cfg.Invert}
|
||||
h, err := build(c, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
@@ -67,7 +69,7 @@ func TestHandleAnnounce(t *testing.T) {
|
||||
|
||||
nctx, err := h.HandleAnnounce(ctx, req, resp)
|
||||
require.Equal(t, ctx, nctx)
|
||||
if tt.approved == true {
|
||||
if tt.approved {
|
||||
require.NotEqual(t, err, ErrClientUnapproved)
|
||||
} else {
|
||||
require.Equal(t, err, ErrClientUnapproved)
|
||||
|
||||
+1
-1
@@ -223,7 +223,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Files = append(resp.Files, scr)
|
||||
resp.Data = append(resp.Data, scr)
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
|
||||
@@ -34,10 +34,10 @@ func init() {
|
||||
var (
|
||||
logger = log.NewLogger("middleware/jwt")
|
||||
// ErrMissingJWT is returned when a JWT is missing from a request.
|
||||
ErrMissingJWT = bittorrent.ClientError("unapproved request: missing jwt")
|
||||
ErrMissingJWT = bittorrent.ClientError("request not allowed by mochi: missing jwt")
|
||||
|
||||
// ErrInvalidJWT is returned when a JWT fails to verify.
|
||||
ErrInvalidJWT = bittorrent.ClientError("unapproved request: invalid jwt")
|
||||
ErrInvalidJWT = bittorrent.ClientError("request not allowed by mochi: invalid jwt")
|
||||
|
||||
errJWKsNotSet = errors.New("required parameters not provided: Issuer/Audience/JWKSetURL")
|
||||
|
||||
@@ -160,7 +160,7 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||
err = ErrInvalidJWT
|
||||
} else {
|
||||
var claimIH bittorrent.InfoHash
|
||||
if claimIH, err = bittorrent.NewInfoHash(claims.InfoHash); err != nil {
|
||||
if claimIH, err = bittorrent.NewInfoHashString(claims.InfoHash); err != nil {
|
||||
logger.Info().
|
||||
Err(err).
|
||||
Object("source", req.RequestPeer).
|
||||
@@ -206,7 +206,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
} else {
|
||||
var claimIHs bittorrent.InfoHashes
|
||||
for _, s := range claims.InfoHashes {
|
||||
if providedIh, err := bittorrent.NewInfoHash(s); err == nil {
|
||||
if providedIh, err := bittorrent.NewInfoHashString(s); err == nil {
|
||||
claimIHs = append(claimIHs, providedIh)
|
||||
} else {
|
||||
logger.Info().
|
||||
@@ -248,7 +248,7 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest,
|
||||
func (h *hook) getJWTString(params bittorrent.Params) (jwt string) {
|
||||
if params != nil {
|
||||
var found bool
|
||||
if jwt, found = params.String(h.cfg.Header); found {
|
||||
if jwt, found = params.GetString(h.cfg.Header); found {
|
||||
if strings.HasPrefix(strings.ToLower(jwt), bearerAuthPrefix) {
|
||||
jwt = jwt[len(bearerAuthPrefix):]
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type JWKSKeys struct {
|
||||
|
||||
type params map[string]string
|
||||
|
||||
func (p params) String(key string) (out string, found bool) {
|
||||
func (p params) GetString(key string) (out string, found bool) {
|
||||
out, found = p[key]
|
||||
return
|
||||
}
|
||||
|
||||
+1
-3
@@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
@@ -80,14 +79,13 @@ func (l *Logic) AfterAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||
func (l *Logic) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest) (_ context.Context, resp *bittorrent.ScrapeResponse, err error) {
|
||||
logger.Debug().Object("request", req).Msg("new scrape request")
|
||||
resp = &bittorrent.ScrapeResponse{
|
||||
Files: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
||||
Data: make([]bittorrent.Scrape, 0, len(req.InfoHashes)),
|
||||
}
|
||||
for _, h := range l.preHooks {
|
||||
if ctx, err = h.HandleScrape(ctx, req, resp); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
sort.Sort(&resp.Files)
|
||||
|
||||
logger.Debug().Object("response", resp).Msg("generated scrape response")
|
||||
return ctx, resp, nil
|
||||
|
||||
@@ -56,7 +56,7 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
|
||||
if len(c.HashList) > 0 {
|
||||
init := make([]storage.Entry, 0, len(c.HashList))
|
||||
for _, hashString := range c.HashList {
|
||||
ih, err := bittorrent.NewInfoHash(hashString)
|
||||
ih, err := bittorrent.NewInfoHashString(hashString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("whitelist : %s : %w", hashString, err)
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func build(config conf.MapConfig, st storage.PeerStorage) (h middleware.Hook, er
|
||||
}
|
||||
|
||||
// ErrTorrentUnapproved is the error returned when a torrent hash is invalid.
|
||||
var ErrTorrentUnapproved = bittorrent.ClientError("unapproved torrent")
|
||||
var ErrTorrentUnapproved = bittorrent.ClientError("torrent not allowed by mochi")
|
||||
|
||||
type hook struct {
|
||||
hashContainer container.Container
|
||||
|
||||
@@ -84,7 +84,7 @@ func TestHandleAnnounce(t *testing.T) {
|
||||
req := &bittorrent.AnnounceRequest{}
|
||||
resp := &bittorrent.AnnounceResponse{}
|
||||
|
||||
hashinfo, err := bittorrent.NewInfoHash(tt.ih)
|
||||
hashinfo, err := bittorrent.NewInfoHashString(tt.ih)
|
||||
require.Nil(t, err)
|
||||
|
||||
req.InfoHash = hashinfo
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
// Package str2bytes provides fast, but unsafe functions to convert string to []byte
|
||||
// or vice versa.
|
||||
package str2bytes
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// StringToBytes converts string to slice of bytes
|
||||
// without allocation. Note, that returned slice
|
||||
// must NOT be modified, since strings in Go are
|
||||
// immutable.
|
||||
// See unsafe.Slice.
|
||||
func StringToBytes(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
|
||||
// BytesToString converts slice of bytes to string
|
||||
// without allocation.
|
||||
// See unsafe.String
|
||||
func BytesToString(b []byte) string {
|
||||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
+12
-12
@@ -114,19 +114,19 @@ func (s *store) delPeer(ctx context.Context, infoHashKey, peerID string) error {
|
||||
}
|
||||
|
||||
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
|
||||
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer))
|
||||
}
|
||||
|
||||
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
|
||||
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), r.PackPeer(peer))
|
||||
}
|
||||
|
||||
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
|
||||
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer))
|
||||
}
|
||||
|
||||
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
|
||||
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), r.PackPeer(peer))
|
||||
}
|
||||
|
||||
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) {
|
||||
@@ -134,18 +134,18 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
|
||||
Stringer("infoHash", ih).
|
||||
Object("peer", peer).
|
||||
Msg("graduate leecher")
|
||||
infoHash, peerID := ih.RawString(), peer.RawString()
|
||||
infoHash, peerID := ih.RawString(), r.PackPeer(peer)
|
||||
ihSeederKey := r.InfoHashKey(infoHash, true, peer.Addr().Is6())
|
||||
ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6())
|
||||
var moved bool
|
||||
if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil {
|
||||
if moved {
|
||||
err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL))
|
||||
} else {
|
||||
err = s.addPeer(ctx, ihSeederKey, peerID)
|
||||
if !moved {
|
||||
err = s.SAdd(ctx, ihSeederKey, peerID).Err()
|
||||
}
|
||||
if err == nil {
|
||||
err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err()
|
||||
if err != nil {
|
||||
if err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)); err == nil {
|
||||
err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -160,7 +160,7 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe
|
||||
Bool("v6", v6).
|
||||
Msg("announce peers")
|
||||
|
||||
return s.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
|
||||
return s.GetPeers(ctx, ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd {
|
||||
return s.SRandMemberN(ctx, infoHashKey, int64(maxCount))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -177,14 +177,15 @@ func (p *peers) len() int {
|
||||
return len(p.m)
|
||||
}
|
||||
|
||||
func (p *peers) keys(fn func(k bittorrent.Peer) bool) {
|
||||
func (p *peers) keys(fn func(k bittorrent.Peer) bool) bool {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
for k := range p.m {
|
||||
if !fn(k) {
|
||||
break
|
||||
return false
|
||||
}
|
||||
}
|
||||
p.RUnlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *peers) forEach(fn func(k bittorrent.Peer, v int64) bool) {
|
||||
@@ -268,7 +269,7 @@ func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 {
|
||||
// There are twice the amount of shards specified by the user, the first
|
||||
// half is dedicated to IPv4 swarms and the second half is dedicated to
|
||||
// IPv6 swarms.
|
||||
idx := binary.BigEndian.Uint32([]byte(infoHash[:4])) % (uint32(len(ps.shards)) / 2)
|
||||
idx := binary.BigEndian.Uint32(infoHash.Bytes()[:4]) % (uint32(len(ps.shards)) / 2)
|
||||
if v6 {
|
||||
idx += uint32(len(ps.shards) / 2)
|
||||
}
|
||||
@@ -417,8 +418,7 @@ func (ps *peerStore) AnnouncePeers(_ context.Context, ih bittorrent.InfoHash, fo
|
||||
if forSeeder {
|
||||
sw.leechers.keys(rangeFn)
|
||||
} else {
|
||||
sw.seeders.keys(rangeFn)
|
||||
if numWant > 0 {
|
||||
if sw.seeders.keys(rangeFn) {
|
||||
sw.leechers.keys(rangeFn)
|
||||
}
|
||||
}
|
||||
@@ -446,8 +446,8 @@ func (ps *peerStore) ScrapeSwarm(_ context.Context, ih bittorrent.InfoHash) (lee
|
||||
Stringer("infoHash", ih).
|
||||
Msg("scrape swarm")
|
||||
|
||||
leechers, seeders = ps.countPeers(ih, true)
|
||||
l, s := ps.countPeers(ih, false)
|
||||
leechers, seeders = ps.countPeers(ih, false)
|
||||
l, s := ps.countPeers(ih, true)
|
||||
leechers, seeders = leechers+l, seeders+s
|
||||
|
||||
return
|
||||
|
||||
+21
-20
@@ -359,15 +359,15 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
|
||||
func (s *store) putPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) {
|
||||
logger.Trace().
|
||||
Stringer("infoHash", ih).
|
||||
Hex("infoHash", ih).
|
||||
Object("peer", peer).
|
||||
Bool("seeder", seeder).
|
||||
Msg("put peer")
|
||||
_, err = s.Exec(ctx, s.Peer.AddQuery, pgx.NamedArgs{
|
||||
pInfoHash: []byte(ih),
|
||||
pPeerID: peer.ID[:],
|
||||
pInfoHash: ih,
|
||||
pPeerID: peer.ID.Bytes(),
|
||||
pAddress: net.IP(peer.Addr().AsSlice()),
|
||||
pPort: peer.Port(),
|
||||
pSeeder: seeder,
|
||||
@@ -377,14 +377,14 @@ func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor
|
||||
return
|
||||
}
|
||||
|
||||
func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
|
||||
func (s *store) delPeer(ctx context.Context, ih []byte, peer bittorrent.Peer, seeder bool) (err error) {
|
||||
logger.Trace().
|
||||
Stringer("infoHash", ih).
|
||||
Hex("infoHash", ih).
|
||||
Object("peer", peer).
|
||||
Msg("del peer")
|
||||
_, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{
|
||||
pInfoHash: []byte(ih),
|
||||
pPeerID: peer.ID[:],
|
||||
pInfoHash: ih,
|
||||
pPeerID: peer.ID.Bytes(),
|
||||
pAddress: net.IP(peer.Addr().AsSlice()),
|
||||
pPort: peer.Port(),
|
||||
pSeeder: seeder,
|
||||
@@ -393,19 +393,19 @@ func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor
|
||||
}
|
||||
|
||||
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.putPeer(ctx, ih, peer, true)
|
||||
return s.putPeer(ctx, ih.Bytes(), peer, true)
|
||||
}
|
||||
|
||||
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.delPeer(ctx, ih, peer, true)
|
||||
return s.delPeer(ctx, ih.Bytes(), peer, true)
|
||||
}
|
||||
|
||||
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.putPeer(ctx, ih, peer, false)
|
||||
return s.putPeer(ctx, ih.Bytes(), peer, false)
|
||||
}
|
||||
|
||||
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return s.delPeer(ctx, ih, peer, false)
|
||||
return s.delPeer(ctx, ih.Bytes(), peer, false)
|
||||
}
|
||||
|
||||
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
@@ -414,10 +414,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
|
||||
Object("peer", peer).
|
||||
Msg("graduate leecher")
|
||||
var batch pgx.Batch
|
||||
ihb := []byte(ih)
|
||||
ihb := ih.Bytes()
|
||||
batch.Queue(s.Peer.GraduateQuery, pgx.NamedArgs{
|
||||
pInfoHash: ihb,
|
||||
pPeerID: peer.ID[:],
|
||||
pPeerID: peer.ID.Bytes(),
|
||||
pAddress: net.IP(peer.Addr().AsSlice()),
|
||||
pPort: peer.Port(),
|
||||
})
|
||||
@@ -425,10 +425,10 @@ func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pee
|
||||
return s.txBatch(ctx, &batch)
|
||||
}
|
||||
|
||||
func (s *store) getPeers(ctx context.Context, ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
|
||||
func (s *store) getPeers(ctx context.Context, ih []byte, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
|
||||
var rows pgx.Rows
|
||||
if rows, err = s.Query(ctx, s.Announce.Query, pgx.NamedArgs{
|
||||
pInfoHash: []byte(ih),
|
||||
pInfoHash: ih,
|
||||
pSeeder: seeders,
|
||||
pV6: isV6,
|
||||
pCount: maxCount,
|
||||
@@ -503,12 +503,13 @@ func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSe
|
||||
Int("numWant", numWant).
|
||||
Bool("v6", v6).
|
||||
Msg("announce peers")
|
||||
ihb := ih.Bytes()
|
||||
if forSeeder {
|
||||
peers, err = s.getPeers(ctx, ih, false, numWant, v6)
|
||||
peers, err = s.getPeers(ctx, ihb, false, numWant, v6)
|
||||
} else {
|
||||
if peers, err = s.getPeers(ctx, ih, true, numWant, v6); err == nil {
|
||||
if peers, err = s.getPeers(ctx, ihb, true, numWant, v6); err == nil {
|
||||
var addPeers []bittorrent.Peer
|
||||
addPeers, err = s.getPeers(ctx, ih, false, numWant-len(peers), v6)
|
||||
addPeers, err = s.getPeers(ctx, ihb, false, numWant-len(peers), v6)
|
||||
peers = append(peers, addPeers...)
|
||||
}
|
||||
}
|
||||
@@ -568,7 +569,7 @@ func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leeche
|
||||
logger.Trace().
|
||||
Stringer("infoHash", ih).
|
||||
Msg("scrape swarm")
|
||||
ihb := []byte(ih)
|
||||
ihb := ih.Bytes()
|
||||
if seeders, leechers, err = s.countPeers(ctx, ihb); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
+55
-12
@@ -24,13 +24,18 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sot-tech/mochi/pkg/str2bytes"
|
||||
|
||||
"github.com/sot-tech/mochi/bittorrent"
|
||||
"github.com/sot-tech/mochi/pkg/conf"
|
||||
@@ -389,20 +394,30 @@ func (ps *store) delPeer(ctx context.Context, infoHashKey, peerCountKey, peerID
|
||||
return err
|
||||
}
|
||||
|
||||
// PackPeer generates concatenation of PeerID, net port and IP-address
|
||||
func PackPeer(p bittorrent.Peer) string {
|
||||
ip := p.Addr()
|
||||
b := make([]byte, bittorrent.PeerIDLen+2+(ip.BitLen()/8))
|
||||
copy(b[:bittorrent.PeerIDLen], p.ID.Bytes())
|
||||
binary.BigEndian.PutUint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2], p.Port())
|
||||
copy(b[bittorrent.PeerIDLen+2:], ip.AsSlice())
|
||||
return str2bytes.BytesToString(b)
|
||||
}
|
||||
|
||||
func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
|
||||
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer))
|
||||
}
|
||||
|
||||
func (ps *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
|
||||
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, PackPeer(peer))
|
||||
}
|
||||
|
||||
func (ps *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
|
||||
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer))
|
||||
}
|
||||
|
||||
func (ps *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
|
||||
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, PackPeer(peer))
|
||||
}
|
||||
|
||||
func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
|
||||
@@ -411,7 +426,7 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe
|
||||
Object("peer", peer).
|
||||
Msg("graduate leecher")
|
||||
|
||||
infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6()
|
||||
infoHash, peerID, isV6 := ih.RawString(), PackPeer(peer), peer.Addr().Is6()
|
||||
ihSeederKey, ihLeecherKey := InfoHashKey(infoHash, true, isV6), InfoHashKey(infoHash, false, isV6)
|
||||
|
||||
return ps.tx(ctx, func(tx redis.Pipeliner) error {
|
||||
@@ -438,12 +453,42 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe
|
||||
})
|
||||
}
|
||||
|
||||
// peerMinimumLen is the least allowed length of string serialized Peer
|
||||
const peerMinimumLen = bittorrent.PeerIDLen + 2 + net.IPv4len
|
||||
|
||||
var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (InfoHash + Port + IPv4))", peerMinimumLen)
|
||||
|
||||
// UnpackPeer constructs Peer from serialized by Peer.PackPeer data: PeerID[20by]Port[2by]net.IP[4/16by]
|
||||
func UnpackPeer(data string) (bittorrent.Peer, error) {
|
||||
var peer bittorrent.Peer
|
||||
if len(data) < peerMinimumLen {
|
||||
return peer, errInvalidPeerDataSize
|
||||
}
|
||||
b := str2bytes.StringToBytes(data)
|
||||
peerID, err := bittorrent.NewPeerID(b[:bittorrent.PeerIDLen])
|
||||
if err == nil {
|
||||
if addr, isOk := netip.AddrFromSlice(b[bittorrent.PeerIDLen+2:]); isOk {
|
||||
peer = bittorrent.Peer{
|
||||
ID: peerID,
|
||||
AddrPort: netip.AddrPortFrom(
|
||||
addr.Unmap(),
|
||||
binary.BigEndian.Uint16(b[bittorrent.PeerIDLen:bittorrent.PeerIDLen+2]),
|
||||
),
|
||||
}
|
||||
} else {
|
||||
err = bittorrent.ErrInvalidIP
|
||||
}
|
||||
}
|
||||
|
||||
return peer, err
|
||||
}
|
||||
|
||||
func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers []bittorrent.Peer, err error) {
|
||||
var peerIds []string
|
||||
peerIds, err = peersResult.Result()
|
||||
if err = NoResultErr(err); err == nil {
|
||||
for _, peerID := range peerIds {
|
||||
if p, err := bittorrent.NewPeer(peerID); err == nil {
|
||||
if p, err := UnpackPeer(peerID); err == nil {
|
||||
peers = append(peers, p)
|
||||
} else {
|
||||
logger.Error().Err(err).Str("peerID", peerID).Msg("unable to decode peer")
|
||||
@@ -453,13 +498,13 @@ func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers [
|
||||
return
|
||||
}
|
||||
|
||||
type getPeersFn func(string, int) *redis.StringSliceCmd
|
||||
type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd
|
||||
|
||||
// GetPeers retrieves peers for provided info hash by calling membersFn and
|
||||
// converts result to bittorrent.Peer array.
|
||||
// If forSeeder set to true - returns only leechers, if false -
|
||||
// seeders and if maxCount not reached - leechers.
|
||||
func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) {
|
||||
func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) {
|
||||
infoHash := ih.RawString()
|
||||
|
||||
infoHashKeys := make([]string, 1, 2)
|
||||
@@ -473,7 +518,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount
|
||||
|
||||
for _, infoHashKey := range infoHashKeys {
|
||||
var peers []bittorrent.Peer
|
||||
peers, err = ps.parsePeersList(membersFn(infoHashKey, maxCount))
|
||||
peers, err = ps.parsePeersList(membersFn(ctx, infoHashKey, maxCount))
|
||||
maxCount -= len(peers)
|
||||
out = append(out, peers...)
|
||||
if err != nil || maxCount <= 0 {
|
||||
@@ -501,9 +546,7 @@ func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forS
|
||||
Bool("v6", v6).
|
||||
Msg("announce peers")
|
||||
|
||||
return ps.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
|
||||
return ps.HRandField(ctx, infoHashKey, maxCount)
|
||||
})
|
||||
return ps.GetPeers(ctx, ih, forSeeder, numWant, v6, ps.HRandField)
|
||||
}
|
||||
|
||||
type getPeerCountFn func(context.Context, string) *redis.IntCmd
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ihCount = 1000
|
||||
peersCount = 1000
|
||||
ihCount = 100
|
||||
peersCount = 100
|
||||
)
|
||||
|
||||
type benchData struct {
|
||||
|
||||
Reference in New Issue
Block a user