(untested) Merge commit e56ad81 from https://github.com/jzelinskie/chihaya

* rename/replace redis keys
This commit is contained in:
Lawrence, Rendall
2022-04-14 19:12:33 +03:00
parent 05fe8e113a
commit 5c2471ca9b
23 changed files with 560 additions and 564 deletions

View File

@@ -10,6 +10,7 @@ import (
"encoding/hex"
"fmt"
"net"
"net/netip"
"time"
"github.com/pkg/errors"
@@ -179,17 +180,17 @@ func (r AnnounceResponse) LogFields() log.Fields {
// ScrapeRequest represents the parsed parameters from a scrape request.
type ScrapeRequest struct {
AddressFamily AddressFamily
InfoHashes []InfoHash
Params Params
Peer
InfoHashes []InfoHash
Params Params
}
// LogFields renders the current response as a set of log fields.
func (r ScrapeRequest) LogFields() log.Fields {
return log.Fields{
"addressFamily": r.AddressFamily,
"infoHashes": r.InfoHashes,
"params": r.Params,
"peer": r.Peer,
"infoHashes": r.InfoHashes,
"params": r.Params,
}
}
@@ -216,54 +217,18 @@ type Scrape struct {
Incomplete uint32
}
// AddressFamily is the address family of an IP address.
type AddressFamily uint8
func (af AddressFamily) String() string {
switch af {
case IPv4:
return "IPv4"
case IPv6:
return "IPv6"
default:
return "<unknown>"
}
}
// AddressFamily constants.
const (
IPv4 AddressFamily = iota
IPv6
)
// IP is a net.IP with an AddressFamily.
type IP struct {
net.IP
AddressFamily
}
func (ip IP) String() string {
return ip.IP.String()
}
// Peer represents the connection details of a peer that is returned in an
// announce response.
type Peer struct {
ID PeerID
IP IP
Port uint16
ID PeerID
netip.AddrPort
}
// PeerMinimumLen is the least allowed length of string serialized Peer
const PeerMinimumLen = PeerIDLen + 2 + net.IPv4len
var (
// ErrInvalidAddressFamily holds error about invalid address family
ErrInvalidAddressFamily = fmt.Errorf("address family must be %d(IPv4) or %d(IPv6)", IPv4, IPv6)
// ErrInvalidPeerDataSize holds error about invalid Peer data size
ErrInvalidPeerDataSize = fmt.Errorf("invalid peer data it must be at least %d bytes (InfoHash + Port + IPv4)", PeerMinimumLen)
)
// ErrInvalidPeerDataSize holds error about invalid Peer data size
var ErrInvalidPeerDataSize = fmt.Errorf("invalid peer data it must be at least %d bytes (InfoHash + Port + IPv4)", PeerMinimumLen)
// NewPeer constructs Peer from serialized by Peer.RawString data: PeerID[20by]Port[2by]net.IP[4/16by]
func NewPeer(data string) (Peer, error) {
@@ -271,21 +236,19 @@ func NewPeer(data string) (Peer, error) {
if len(data) < PeerMinimumLen {
return peer, ErrInvalidPeerDataSize
}
peerID, err := NewPeerID([]byte(data[:PeerIDLen]))
b := []byte(data)
peerID, err := NewPeerID(b[:PeerIDLen])
if err == nil {
peer = Peer{
ID: peerID,
Port: binary.BigEndian.Uint16([]byte(data[PeerIDLen : PeerIDLen+2])),
IP: IP{IP: net.IP(data[PeerIDLen+2:])},
}
if ip := peer.IP.To4(); ip != nil {
peer.IP.IP = ip
peer.IP.AddressFamily = IPv4
} else if len(peer.IP.IP) == net.IPv6len { // implies toReturn.IP.To4() == nil
peer.IP.AddressFamily = IPv6
if addr, isOk := netip.AddrFromSlice(b[PeerIDLen+2:]); isOk {
peer = Peer{
ID: peerID,
AddrPort: netip.AddrPortFrom(
addr,
binary.BigEndian.Uint16(b[PeerIDLen:PeerIDLen+2]),
),
}
} else {
err = ErrInvalidAddressFamily
err = ErrInvalidIP
}
}
@@ -296,15 +259,16 @@ func NewPeer(data string) (Peer, error) {
// The string will have the format <PeerID>@[<IP>]:<port>, for example
// "0102030405060708090a0b0c0d0e0f1011121314@[10.11.12.13]:1234"
func (p Peer) String() string {
return fmt.Sprintf("%s@[%s]:%d", p.ID.String(), p.IP.String(), p.Port)
return fmt.Sprintf("%s@[%s]:%d", p.ID, p.Addr(), p.Port())
}
// RawString generates concatenation of PeerID, net port and IP-address
func (p Peer) RawString() string {
b := make([]byte, PeerIDLen+2+len(p.IP.IP))
ip := p.Addr().Unmap()
b := make([]byte, PeerIDLen+2+(ip.BitLen()/8))
copy(b[:PeerIDLen], p.ID[:])
binary.BigEndian.PutUint16(b[PeerIDLen:PeerIDLen+2], p.Port)
copy(b[PeerIDLen+2:], p.IP.IP)
binary.BigEndian.PutUint16(b[PeerIDLen:PeerIDLen+2], p.Port())
copy(b[PeerIDLen+2:], ip.AsSlice())
return string(b)
}
@@ -312,8 +276,8 @@ func (p Peer) RawString() string {
func (p Peer) LogFields() log.Fields {
return log.Fields{
"ID": p.ID,
"IP": p.IP,
"port": p.Port,
"IP": p.Addr().String(),
"port": p.Port(),
}
}
@@ -321,7 +285,10 @@ func (p Peer) LogFields() log.Fields {
func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID }
// EqualEndpoint reports whether p and x have the same endpoint.
func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP.IP) }
func (p Peer) EqualEndpoint(x Peer) bool {
return p.Port() == x.Port() &&
p.Addr().Compare(x.Addr()) == 0
}
// ClientError represents an error that should be exposed to the client over
// the BitTorrent protocol implementation.

View File

@@ -2,7 +2,7 @@ package bittorrent
import (
"fmt"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
@@ -13,26 +13,6 @@ var (
expected = "0102030405060708090a0b0c0d0e0f1011121314"
)
var peerStringTestCases = []struct {
input Peer
expected string
}{
{
input: Peer{
IP: IP{net.IPv4(10, 11, 12, 1), IPv4},
Port: 1234,
},
expected: fmt.Sprintf("%s@[10.11.12.1]:1234", expected),
},
{
input: Peer{
IP: IP{net.ParseIP("2001:db8::ff00:42:8329"), IPv6},
Port: 1234,
},
expected: fmt.Sprintf("%s@[2001:db8::ff00:42:8329]:1234", expected),
},
}
func TestPeerID_String(t *testing.T) {
pid, err := NewPeerID(b)
require.Nil(t, err)
@@ -49,6 +29,26 @@ func TestInfoHash_String(t *testing.T) {
func TestPeer_String(t *testing.T) {
pid, err := NewPeerID(b)
require.Nil(t, err)
id, _ := NewPeerID(b)
peerStringTestCases := []struct {
input Peer
expected string
}{
{
input: Peer{
ID: id,
AddrPort: netip.MustParseAddrPort("10.11.12.1:1234"),
},
expected: fmt.Sprintf("%s@[10.11.12.1]:1234", expected),
},
{
input: Peer{
ID: id,
AddrPort: netip.MustParseAddrPort("[2001:db8::ff00:42:8329]:1234"),
},
expected: fmt.Sprintf("%s@[2001:db8::ff00:42:8329]:1234", expected),
},
}
for _, c := range peerStringTestCases {
c.input.ID = pid
got := c.input.String()

View File

@@ -1,12 +1,11 @@
package bittorrent
import (
"errors"
"strings"
)
// ErrUnknownEvent is returned when New fails to return an event.
var ErrUnknownEvent = errors.New("unknown event")
var ErrUnknownEvent = ClientError("unknown event")
// Event represents an event done by a BitTorrent client.
type Event uint8
@@ -46,12 +45,14 @@ func init() {
}
// NewEvent returns the proper Event given a string.
func NewEvent(eventStr string) (Event, error) {
func NewEvent(eventStr string) (evt Event, err error) {
if e, ok := stringToEvent[strings.ToLower(eventStr)]; ok {
return e, nil
evt = e
} else {
evt, err = None, ErrUnknownEvent
}
return None, ErrUnknownEvent
return
}
// String implements Stringer for an event.

View File

@@ -1,21 +1,23 @@
package bittorrent
import (
"net"
"net/netip"
"github.com/sot-tech/mochi/pkg/log"
)
// ErrInvalidIP indicates an invalid IP for an Announce.
var ErrInvalidIP = ClientError("invalid IP")
var (
// ErrInvalidIP indicates an invalid IP for an Announce.
ErrInvalidIP = ClientError("invalid IP")
// ErrInvalidPort indicates an invalid Port for an Announce.
var ErrInvalidPort = ClientError("invalid port")
// ErrInvalidPort indicates an invalid Port for an Announce.
ErrInvalidPort = ClientError("invalid port")
)
// SanitizeAnnounce enforces a max and default NumWant and coerces the peer's
// IP address into the proper format.
func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) error {
if r.Port == 0 {
if r.Port() == 0 {
return ErrInvalidPort
}
@@ -25,12 +27,8 @@ func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) err
r.NumWant = maxNumWant
}
if ip := r.Peer.IP.To4(); ip != nil {
r.Peer.IP.IP = ip
r.Peer.IP.AddressFamily = IPv4
} else if len(r.Peer.IP.IP) == net.IPv6len { // implies r.Peer.IP.To4() == nil
r.Peer.IP.AddressFamily = IPv6
} else {
r.AddrPort = netip.AddrPortFrom(r.Addr().Unmap(), r.Port())
if !r.Addr().IsValid() || r.Addr().IsUnspecified() {
return ErrInvalidIP
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/http"
"net/netip"
"time"
"github.com/julienschmidt/httprouter"
@@ -282,12 +283,12 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http
if f.EnableRequestTiming {
start = time.Now()
}
var af *bittorrent.AddressFamily
var addr netip.Addr
defer func() {
if f.EnableRequestTiming {
recordResponseDuration("announce", af, err, time.Since(start))
recordResponseDuration("announce", addr, err, time.Since(start))
} else {
recordResponseDuration("announce", af, err, time.Duration(0))
recordResponseDuration("announce", addr, err, time.Duration(0))
}
}()
@@ -296,8 +297,7 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http
WriteError(w, err)
return
}
af = new(bittorrent.AddressFamily)
*af = req.IP.AddressFamily
addr = req.AddrPort.Addr()
ctx := injectRouteParamsToContext(context.Background(), ps)
ctx, resp, err := f.logic.HandleAnnounce(ctx, req)
@@ -323,12 +323,12 @@ func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httpro
if f.EnableRequestTiming {
start = time.Now()
}
var af *bittorrent.AddressFamily
var addr netip.Addr
defer func() {
if f.EnableRequestTiming {
recordResponseDuration("scrape", af, err, time.Since(start))
recordResponseDuration("scrape", addr, err, time.Since(start))
} else {
recordResponseDuration("scrape", af, err, time.Duration(0))
recordResponseDuration("scrape", addr, err, time.Duration(0))
}
}()
@@ -345,18 +345,12 @@ func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httpro
return
}
reqIP := net.ParseIP(host)
if reqIP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
} else if len(reqIP) == net.IPv6len { // implies reqIP.To4() == nil
req.AddressFamily = bittorrent.IPv6
} else {
addr, err = netip.ParseAddr(host)
if err != nil || addr.IsUnspecified() {
log.Error("http: invalid IP: neither v4 nor v6", log.Fields{"RemoteAddr": r.RemoteAddr})
WriteError(w, bittorrent.ErrInvalidIP)
return
}
af = new(bittorrent.AddressFamily)
*af = req.AddressFamily
ctx := injectRouteParamsToContext(context.Background(), ps)
ctx, resp, err := f.logic.HandleScrape(ctx, req)

View File

@@ -2,8 +2,8 @@ package http
import (
"errors"
"net"
"net/http"
"net/netip"
"github.com/sot-tech/mochi/bittorrent"
)
@@ -28,6 +28,16 @@ const (
defaultMaxScrapeInfoHashes = 50
)
var (
errNoInfoHash = bittorrent.ClientError("no info hash supplied")
errMultipleInfoHashes = bittorrent.ClientError("multiple info hashes supplied")
errInvalidPeerID = bittorrent.ClientError("peer ID invalid or not provided")
errInvalidParameterLeft = bittorrent.ClientError("parameter 'left' invalid or not provided")
errInvalidParameterDownloaded = bittorrent.ClientError("parameter 'downloaded' invalid or not provided")
errInvalidParameterUploaded = bittorrent.ClientError("parameter 'uploaded' invalid or not provided")
errInvalidParameterNumWant = bittorrent.ClientError("parameter 'num want' invalid or not provided")
)
// ParseAnnounce parses an bittorrent.AnnounceRequest from an http.Request.
func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequest, error) {
qp, err := bittorrent.ParseURLData(r.RequestURI)
@@ -41,9 +51,8 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ
var eventStr string
eventStr, request.EventProvided = qp.String("event")
if request.EventProvided {
request.Event, err = bittorrent.NewEvent(eventStr)
if err != nil {
return nil, bittorrent.ClientError("failed to provide valid client event")
if request.Event, err = bittorrent.NewEvent(eventStr); err != nil {
return nil, err
}
} else {
request.Event = bittorrent.None
@@ -56,17 +65,17 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ
// Parse the infohash from the request.
infoHashes := qp.InfoHashes()
if len(infoHashes) < 1 {
return nil, bittorrent.ClientError("no info_hash parameter supplied")
return nil, errNoInfoHash
}
if len(infoHashes) > 1 {
return nil, bittorrent.ClientError("multiple info_hash parameters supplied")
return nil, errMultipleInfoHashes
}
request.InfoHash = infoHashes[0]
// Parse the PeerID from the request.
peerID, ok := qp.String("peer_id")
if !ok {
return nil, bittorrent.ClientError("failed to parse parameter: peer_id")
return nil, errInvalidPeerID
}
request.Peer.ID, err = bittorrent.NewPeerID([]byte(peerID))
if err != nil {
@@ -75,25 +84,25 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ
// Determine the number of remaining bytes for the client.
request.Left, err = qp.Uint("left", 64)
if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: left")
return nil, errInvalidParameterLeft
}
// Determine the number of bytes downloaded by the client.
request.Downloaded, err = qp.Uint("downloaded", 64)
if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: downloaded")
return nil, errInvalidParameterDownloaded
}
// Determine the number of bytes shared by the client.
request.Uploaded, err = qp.Uint("uploaded", 64)
if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: uploaded")
return nil, errInvalidParameterUploaded
}
// Determine the number of peers the client wants in the response.
numwant, err := qp.Uint("numwant", 32)
if err != nil && !errors.Is(err, bittorrent.ErrKeyNotFound) {
return nil, bittorrent.ClientError("failed to parse parameter: numwant")
return nil, errInvalidParameterNumWant
}
// If there were no errors, the user actually provided the numwant.
request.NumWantProvided = err == nil
@@ -102,21 +111,22 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ
// Parse the port where the client is listening.
port, err := qp.Uint("port", 16)
if err != nil {
return nil, bittorrent.ClientError("failed to parse parameter: port")
return nil, bittorrent.ErrInvalidPort
}
request.Peer.Port = uint16(port)
// Parse the IP address where the client is listening.
request.Peer.IP.IP, request.IPProvided = requestedIP(r, qp, opts)
if request.Peer.IP.IP == nil {
return nil, bittorrent.ClientError("failed to parse peer IP address")
ip, spoofed, err := requestedIP(r, qp, opts)
if err != nil {
return nil, bittorrent.ErrInvalidIP
}
request.Peer.AddrPort = netip.AddrPortFrom(ip, uint16(port))
request.IPProvided = spoofed
if err = bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil {
request = nil
}
if err := bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil {
return nil, err
}
return request, nil
return request, err
}
// ParseScrape parses an bittorrent.ScrapeRequest from an http.Request.
@@ -128,7 +138,7 @@ func ParseScrape(r *http.Request, opts ParseOptions) (*bittorrent.ScrapeRequest,
infoHashes := qp.InfoHashes()
if len(infoHashes) < 1 {
return nil, bittorrent.ClientError("no info_hash parameter supplied")
return nil, errNoInfoHash
}
request := &bittorrent.ScrapeRequest{
@@ -144,27 +154,29 @@ func ParseScrape(r *http.Request, opts ParseOptions) (*bittorrent.ScrapeRequest,
}
// requestedIP determines the IP address for a BitTorrent client request.
func requestedIP(r *http.Request, p bittorrent.Params, opts ParseOptions) (ip net.IP, provided bool) {
func requestedIP(r *http.Request, p bittorrent.Params, opts ParseOptions) (netip.Addr, bool, error) {
if opts.AllowIPSpoofing {
if ipstr, ok := p.String("ip"); ok {
return net.ParseIP(ipstr), true
addr, err := netip.ParseAddr(ipstr)
return addr, true, err
}
if ipstr, ok := p.String("ipv4"); ok {
return net.ParseIP(ipstr), true
addr, err := netip.ParseAddr(ipstr)
return addr, true, err
}
if ipstr, ok := p.String("ipv6"); ok {
return net.ParseIP(ipstr), true
addr, err := netip.ParseAddr(ipstr)
return addr, true, err
}
}
if opts.RealIPHeader != "" {
if ip := r.Header.Get(opts.RealIPHeader); ip != "" {
return net.ParseIP(ip), false
}
if ipstr := r.Header.Get(opts.RealIPHeader); ipstr != "" && opts.RealIPHeader != "" {
addr, err := netip.ParseAddr(ipstr)
return addr, false, err
}
host, _, _ := net.SplitHostPort(r.RemoteAddr)
return net.ParseIP(host), false
addrPort, err := netip.ParseAddrPort(r.RemoteAddr)
return addrPort.Addr(), false, err
}

View File

@@ -2,11 +2,13 @@ package http
import (
"errors"
"net/netip"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/metrics"
)
func init() {
@@ -24,7 +26,7 @@ var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
// recordResponseDuration records the duration of time to respond to a Request
// in milliseconds.
func recordResponseDuration(action string, af *bittorrent.AddressFamily, err error, duration time.Duration) {
func recordResponseDuration(action string, addr netip.Addr, err error, duration time.Duration) {
var errString string
if err != nil {
var clientErr bittorrent.ClientError
@@ -35,16 +37,7 @@ func recordResponseDuration(action string, af *bittorrent.AddressFamily, err err
}
}
var afString string
if af == nil {
afString = "Unknown"
} else if *af == bittorrent.IPv4 {
afString = "IPv4"
} else if *af == bittorrent.IPv6 {
afString = "IPv6"
}
promResponseDurationMilliseconds.
WithLabelValues(action, afString, errString).
WithLabelValues(action, metrics.AddressFamily(addr), errString).
Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond))
}

View File

@@ -98,29 +98,25 @@ func WriteScrapeResponse(w http.ResponseWriter, resp *bittorrent.ScrapeResponse)
}
func compact4(peer bittorrent.Peer) (buf []byte) {
if ip := peer.IP.To4(); ip == nil {
panic("non-IPv4 IP for Peer in IPv4Peers")
} else {
buf = ip
}
buf = append(buf, byte(peer.Port>>8), byte(peer.Port))
ip := peer.AddrPort.Addr().As4()
buf = append(buf, ip[:]...)
port := peer.AddrPort.Port()
buf = append(buf, byte(port>>8), byte(port&0xff))
return
}
func compact6(peer bittorrent.Peer) (buf []byte) {
if ip := peer.IP.To16(); ip == nil {
panic("non-IPv6 IP for Peer in IPv6Peers")
} else {
buf = ip
}
buf = append(buf, byte(peer.Port>>8), byte(peer.Port))
ip := peer.AddrPort.Addr().As16()
buf = append(buf, ip[:]...)
port := peer.AddrPort.Port()
buf = append(buf, byte(port>>8), byte(port&0xff))
return
}
func dict(peer bittorrent.Peer) map[string]any {
return map[string]any{
"peer id": string(peer.ID[:]),
"ip": peer.IP.String(),
"port": peer.Port,
"ip": peer.Addr(),
"port": peer.Port(),
}
}

View File

@@ -4,7 +4,7 @@ import (
"crypto/hmac"
"encoding/binary"
"hash"
"net"
"net/netip"
"time"
"github.com/minio/sha256-simd"
@@ -19,14 +19,14 @@ const ttl = 2 * time.Minute
// described by BEP 15.
// This is a wrapper around creating a new ConnectionIDGenerator and generating
// an ID. It is recommended to use the generator for performance.
func NewConnectionID(ip net.IP, now time.Time, key string) []byte {
func NewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
return NewConnectionIDGenerator(key).Generate(ip, now)
}
// ValidConnectionID determines whether a connection identifier is legitimate.
// This is a wrapper around creating a new ConnectionIDGenerator and validating
// the ID. It is recommended to use the generator for performance.
func ValidConnectionID(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration, key string) bool {
func ValidConnectionID(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration, key string) bool {
return NewConnectionIDGenerator(key).Validate(connectionID, ip, now, maxClockSkew)
}
@@ -85,13 +85,14 @@ func (g *ConnectionIDGenerator) reset() {
// The generated ID is written to g.connID, which is also returned. g.connID
// will be reused, so it must not be referenced after returning the generator
// to a pool and will be overwritten be subsequent calls to Generate!
func (g *ConnectionIDGenerator) Generate(ip net.IP, now time.Time) []byte {
func (g *ConnectionIDGenerator) Generate(ip netip.Addr, now time.Time) []byte {
g.reset()
binary.BigEndian.PutUint32(g.connID, uint32(now.Unix()))
g.mac.Write(g.connID[:4])
g.mac.Write(ip)
ipBytes, _ := ip.MarshalBinary()
g.mac.Write(ipBytes)
g.scratch = g.mac.Sum(g.scratch)
copy(g.connID[4:8], g.scratch[:4])
@@ -100,7 +101,7 @@ func (g *ConnectionIDGenerator) Generate(ip net.IP, now time.Time) []byte {
}
// Validate validates the given connection ID for an IP and the current time.
func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration) bool {
func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip netip.Addr, now time.Time, maxClockSkew time.Duration) bool {
ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0)
log.Debug("validating connection ID", log.Fields{"connID": connectionID, "ip": ip, "ts": ts, "now": now})
if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) {
@@ -110,7 +111,8 @@ func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip net.IP, now tim
g.reset()
g.mac.Write(connectionID[:4])
g.mac.Write(ip)
ipBytes, _ := ip.MarshalBinary()
g.mac.Write(ipBytes)
g.scratch = g.mac.Sum(g.scratch)
return hmac.Equal(g.scratch[:4], connectionID[4:])
}

View File

@@ -4,7 +4,7 @@ import (
"crypto/hmac"
"encoding/binary"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
@@ -24,18 +24,19 @@ var golden = []struct {
}{
{0, 1, "127.0.0.1", "", true},
{0, 420420, "127.0.0.1", "", false},
{0, 0, "[::]", "", true},
{0, 0, "::1", "", true},
}
// simpleNewConnectionID generates a new connection ID the explicit way.
// This is used to verify correct behaviour of the generator.
func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte {
func simpleNewConnectionID(ip netip.Addr, now time.Time, key string) []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint32(buf, uint32(now.Unix()))
mac := hmac.New(sha256.New, []byte(key))
mac.Write(buf[:4])
mac.Write(ip)
ipBytes, _ := ip.MarshalBinary()
mac.Write(ipBytes)
macBytes := mac.Sum(nil)[:4]
copy(buf[4:], macBytes)
@@ -48,8 +49,8 @@ func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte {
func TestVerification(t *testing.T) {
for _, tt := range golden {
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
got := ValidConnectionID(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute, tt.key)
cid := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
got := ValidConnectionID(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0), time.Minute, tt.key)
if got != tt.valid {
t.Errorf("expected validity: %t got validity: %t", tt.valid, got)
}
@@ -60,8 +61,8 @@ func TestVerification(t *testing.T) {
func TestGeneration(t *testing.T) {
for _, tt := range golden {
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
want := simpleNewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
got := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
want := simpleNewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
got := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
require.Equal(t, want, got)
})
}
@@ -70,13 +71,13 @@ func TestGeneration(t *testing.T) {
func TestReuseGeneratorGenerate(t *testing.T) {
for _, tt := range golden {
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
cid := NewConnectionID(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
require.Len(t, cid, 8)
gen := NewConnectionIDGenerator(tt.key)
for i := 0; i < 3; i++ {
connID := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0))
connID := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
require.Equal(t, cid, connID)
}
})
@@ -87,9 +88,9 @@ func TestReuseGeneratorValidate(t *testing.T) {
for _, tt := range golden {
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
gen := NewConnectionIDGenerator(tt.key)
cid := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0))
cid := gen.Generate(netip.MustParseAddr(tt.ip), time.Unix(tt.createdAt, 0))
for i := 0; i < 3; i++ {
got := gen.Validate(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute)
got := gen.Validate(cid, netip.MustParseAddr(tt.ip), time.Unix(tt.now, 0), time.Minute)
if got != tt.valid {
t.Errorf("expected validity: %t got validity: %t", tt.valid, got)
}
@@ -99,7 +100,7 @@ func TestReuseGeneratorValidate(t *testing.T) {
}
func BenchmarkSimpleNewConnectionID(b *testing.B) {
ip := net.ParseIP("127.0.0.1")
ip := netip.MustParseAddr("127.0.0.1")
key := "some random string that is hopefully at least this long"
createdAt := time.Now()
@@ -116,7 +117,7 @@ func BenchmarkSimpleNewConnectionID(b *testing.B) {
}
func BenchmarkNewConnectionID(b *testing.B) {
ip := net.ParseIP("127.0.0.1")
ip := netip.MustParseAddr("127.0.0.1")
key := "some random string that is hopefully at least this long"
createdAt := time.Now()
@@ -133,7 +134,7 @@ func BenchmarkNewConnectionID(b *testing.B) {
}
func BenchmarkConnectionIDGenerator_Generate(b *testing.B) {
ip := net.ParseIP("127.0.0.1")
ip := netip.MustParseAddr("127.0.0.1")
key := "some random string that is hopefully at least this long"
createdAt := time.Now()
@@ -155,7 +156,7 @@ func BenchmarkConnectionIDGenerator_Generate(b *testing.B) {
}
func BenchmarkValidConnectionID(b *testing.B) {
ip := net.ParseIP("127.0.0.1")
ip := netip.MustParseAddr("127.0.0.1")
key := "some random string that is hopefully at least this long"
createdAt := time.Now()
cid := NewConnectionID(ip, createdAt, key)
@@ -170,7 +171,7 @@ func BenchmarkValidConnectionID(b *testing.B) {
}
func BenchmarkConnectionIDGenerator_Validate(b *testing.B) {
ip := net.ParseIP("127.0.0.1")
ip := netip.MustParseAddr("127.0.0.1")
key := "some random string that is hopefully at least this long"
createdAt := time.Now()
cid := NewConnectionID(ip, createdAt, key)

View File

@@ -7,9 +7,9 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net"
"net/netip"
"sync"
"time"
@@ -184,7 +184,7 @@ func (t *Frontend) serve() error {
// Read a UDP packet into a reusable buffer.
buffer := pool.Get()
n, addr, err := t.socket.ReadFromUDP(*buffer)
n, addrPort, err := t.socket.ReadFromUDPAddrPort(*buffer)
if err != nil {
pool.Put(buffer)
var netErr net.Error
@@ -206,24 +206,20 @@ func (t *Frontend) serve() error {
defer t.wg.Done()
defer pool.Put(buffer)
if ip := addr.IP.To4(); ip != nil {
addr.IP = ip
}
// Handle the request.
addr := addrPort.Addr()
var start time.Time
if t.EnableRequestTiming {
start = time.Now()
}
action, af, err := t.handleRequest(
// Make sure the IP is copied, not referenced.
Request{(*buffer)[:n], append([]byte{}, addr.IP...)},
ResponseWriter{t.socket, addr},
action, err := t.handleRequest(
Request{(*buffer)[:n], addr},
ResponseWriter{t.socket, addrPort},
)
if t.EnableRequestTiming {
recordResponseDuration(action, af, err, time.Since(start))
recordResponseDuration(action, addr, err, time.Since(start))
} else {
recordResponseDuration(action, af, err, time.Duration(0))
recordResponseDuration(action, addr, err, time.Duration(0))
}
}()
}
@@ -232,23 +228,23 @@ func (t *Frontend) serve() error {
// Request represents a UDP payload received by a Tracker.
type Request struct {
Packet []byte
IP net.IP
IP netip.Addr
}
// ResponseWriter implements the ability to respond to a Request via the
// io.Writer interface.
type ResponseWriter struct {
socket *net.UDPConn
addr *net.UDPAddr
socket *net.UDPConn
addrPort netip.AddrPort
}
// Write implements the io.Writer interface for a ResponseWriter.
func (w ResponseWriter) Write(b []byte) (int, error) {
return w.socket.WriteToUDP(b, w.addr)
return w.socket.WriteToUDPAddrPort(b, w.addrPort)
}
// handleRequest parses and responds to a UDP Request.
func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string, af *bittorrent.AddressFamily, err error) {
func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string, err error) {
if len(r.Packet) < 16 {
// Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt.
@@ -283,16 +279,6 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
return
}
af = new(bittorrent.AddressFamily)
if r.IP.To4() != nil {
*af = bittorrent.IPv4
} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
*af = bittorrent.IPv6
} else {
// Should never happen - we got the IP straight from the UDP packet.
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
}
WriteConnectionID(w, txID, gen.Generate(r.IP, timecache.Now()))
case announceActionID, announceV6ActionID:
@@ -304,8 +290,6 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
WriteError(w, txID, err)
return
}
af = new(bittorrent.AddressFamily)
*af = req.IP.AddressFamily
var ctx context.Context
var resp *bittorrent.AnnounceResponse
@@ -315,7 +299,7 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
return
}
WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, req.IP.AddressFamily == bittorrent.IPv6)
WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, r.IP.Is6())
go t.logic.AfterAnnounce(ctx, req, resp)
@@ -329,17 +313,6 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
return
}
if r.IP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
} else if len(r.IP) == net.IPv6len { // implies r.IP.To4() == nil
req.AddressFamily = bittorrent.IPv6
} else {
// Should never happen - we got the IP straight from the UDP packet.
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
}
af = new(bittorrent.AddressFamily)
*af = req.AddressFamily
var ctx context.Context
var resp *bittorrent.ScrapeResponse
ctx, resp, err = t.logic.HandleScrape(context.Background(), req)

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"sync"
"github.com/sot-tech/mochi/bittorrent"
@@ -41,8 +42,6 @@ var (
}
errMalformedPacket = bittorrent.ClientError("malformed packet")
errMalformedIP = bittorrent.ClientError("malformed IP address")
errMalformedEvent = bittorrent.ClientError("malformed event ID")
errUnknownAction = bittorrent.ClientError("unknown action ID")
errBadConnectionID = bittorrent.ClientError("bad connection ID")
errUnknownOptionType = bittorrent.ClientError("unknown option type")
@@ -89,19 +88,23 @@ func ParseAnnounce(r Request, v6Action bool, opts ParseOptions) (*bittorrent.Ann
eventID := int(r.Packet[83])
if eventID >= len(eventIDs) {
return nil, errMalformedEvent
return nil, bittorrent.ErrUnknownEvent
}
ip := r.IP
ipProvided := false
if ipBytes := r.Packet[84:ipEnd]; opts.AllowIPSpoofing {
// Make sure the bytes are copied to a new slice.
copy(ip, ipBytes)
if opts.AllowIPSpoofing {
ipBytes := r.Packet[84:ipEnd]
spoofed, ok := netip.AddrFromSlice(ipBytes)
if !ok {
return nil, bittorrent.ErrInvalidIP
}
ipProvided = true
ip = spoofed
}
if !opts.AllowIPSpoofing && r.IP == nil {
if !opts.AllowIPSpoofing && r.IP.IsUnspecified() {
// We have no IP address to fallback on.
return nil, errMalformedIP
return nil, bittorrent.ErrInvalidIP
}
numWant := binary.BigEndian.Uint32(r.Packet[ipEnd+4 : ipEnd+8])
@@ -133,18 +136,17 @@ func ParseAnnounce(r Request, v6Action bool, opts ParseOptions) (*bittorrent.Ann
NumWantProvided: true,
EventProvided: true,
Peer: bittorrent.Peer{
ID: peerID,
IP: bittorrent.IP{IP: ip},
Port: port,
ID: peerID,
AddrPort: netip.AddrPortFrom(ip, port),
},
Params: params,
}
if err := bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil {
return nil, err
if err = bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil {
request = nil
}
return request, nil
return request, err
}
type buffer struct {

View File

@@ -2,11 +2,13 @@ package udp
import (
"errors"
"net/netip"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/metrics"
)
func init() {
@@ -24,7 +26,7 @@ var promResponseDurationMilliseconds = prometheus.NewHistogramVec(
// recordResponseDuration records the duration of time to respond to a UDP
// Request in milliseconds.
func recordResponseDuration(action string, af *bittorrent.AddressFamily, err error, duration time.Duration) {
func recordResponseDuration(action string, addr netip.Addr, err error, duration time.Duration) {
var errString string
if err != nil {
var clientErr bittorrent.ClientError
@@ -35,16 +37,7 @@ func recordResponseDuration(action string, af *bittorrent.AddressFamily, err err
}
}
var afString string
if af == nil {
afString = "Unknown"
} else if *af == bittorrent.IPv4 {
afString = "IPv4"
} else if *af == bittorrent.IPv6 {
afString = "IPv6"
}
promResponseDurationMilliseconds.
WithLabelValues(action, afString, errString).
WithLabelValues(action, metrics.AddressFamily(addr), errString).
Observe(float64(duration.Nanoseconds()) / float64(time.Millisecond))
}

View File

@@ -49,8 +49,8 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse,
}
for _, peer := range peers {
_, _ = buf.Write(peer.IP.IP)
_ = binary.Write(buf, binary.BigEndian, peer.Port)
_, _ = buf.Write(peer.Addr().AsSlice())
_ = binary.Write(buf, binary.BigEndian, peer.Port())
}
_, _ = w.Write(buf.Bytes())

View File

@@ -96,7 +96,7 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou
}
// Add the Scrape data to the response.
s := h.store.ScrapeSwarm(req.InfoHash, req.IP.AddressFamily)
s := h.store.ScrapeSwarm(req.InfoHash, req.Peer)
resp.Incomplete = s.Incomplete
resp.Complete = s.Complete
@@ -106,7 +106,8 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou
func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) error {
seeding := req.Left == 0
peers, err := h.store.AnnouncePeers(req.InfoHash, seeding, int(req.NumWant), req.Peer)
max := int(req.NumWant)
peers, err := h.store.AnnouncePeers(req.InfoHash, seeding, max, req.Peer)
if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) {
return err
}
@@ -123,19 +124,19 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor
peers = append(peers, req.Peer)
}
switch req.IP.AddressFamily {
case bittorrent.IPv4:
resp.IPv4Peers = mergePeers(resp.IPv4Peers, peers)
case bittorrent.IPv6:
resp.IPv6Peers = mergePeers(resp.IPv6Peers, peers)
switch addr := req.Peer.Addr(); {
case addr.Is4(), addr.Is4In6():
resp.IPv4Peers = mergePeers(resp.IPv4Peers, peers, max)
case addr.Is6():
resp.IPv6Peers = mergePeers(resp.IPv6Peers, peers, max)
default:
err = bittorrent.ErrInvalidAddressFamily
err = bittorrent.ErrInvalidIP
}
return err
}
func mergePeers(p0, p1 []bittorrent.Peer) (result []bittorrent.Peer) {
func mergePeers(p0, p1 []bittorrent.Peer, max int) (result []bittorrent.Peer) {
peers := make(map[string]bittorrent.Peer, len(p0)+len(p1))
for _, p := range p0 {
peers[p.RawString()] = p
@@ -145,7 +146,11 @@ func mergePeers(p0, p1 []bittorrent.Peer) (result []bittorrent.Peer) {
}
result = make([]bittorrent.Peer, 0, len(peers))
for _, v := range peers {
result = append(result, v)
if len(peers) < max {
result = append(result, v)
} else {
break
}
}
return
}
@@ -156,7 +161,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
}
for _, infoHash := range req.InfoHashes {
resp.Files = append(resp.Files, h.store.ScrapeSwarm(infoHash, req.AddressFamily))
resp.Files = append(resp.Files, h.store.ScrapeSwarm(infoHash, req.Peer))
}
return ctx, nil

View File

@@ -3,7 +3,7 @@ package middleware
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
@@ -42,12 +42,16 @@ func (hooks hookList) handleAnnounce(ctx context.Context, req *bittorrent.Announ
}
func benchHookListV4(b *testing.B, hooks hookList) {
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("1.2.3.4"), AddressFamily: bittorrent.IPv4}}}
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{
AddrPort: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 0),
}}
benchHookList(b, hooks, req)
}
func benchHookListV6(b *testing.B, hooks hookList) {
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}}}
req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{
AddrPort: netip.AddrPortFrom(netip.MustParseAddr("fc00:0001"), 0),
}}
benchHookList(b, hooks, req)
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"net/http"
"net/http/pprof"
"net/netip"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -20,6 +21,18 @@ type Server struct {
srv *http.Server
}
// AddressFamily returns the label value for reporting the address family of an IP address.
func AddressFamily(ip netip.Addr) string {
switch {
case ip.Is4(), ip.Is4In6():
return "IPv4"
case ip.Is6():
return "IPv6"
default:
return "<unknown>"
}
}
// Stop shuts down the server.
func (s *Server) Stop() stop.Result {
c := make(stop.Channel)

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"fmt"
"math"
"net/netip"
"reflect"
"runtime"
"sync"
@@ -228,12 +229,12 @@ func (ps *store) getClock() int64 {
return timecache.NowUnixNano()
}
func (ps *store) shardIndex(infoHash bittorrent.InfoHash, af bittorrent.AddressFamily) uint32 {
func (ps *store) shardIndex(infoHash bittorrent.InfoHash, addr netip.Addr) uint32 {
// There are twice the amount of shards specified by the user, the first
// half is dedicated to IPv4 swarms and the second half is dedicated to
// IPv6 swarms.
idx := binary.BigEndian.Uint32([]byte(infoHash[:4])) % (uint32(len(ps.shards)) / 2)
if af == bittorrent.IPv6 {
if addr.Is6() && !addr.Is4In6() {
idx += uint32(len(ps.shards) / 2)
}
return idx
@@ -248,7 +249,7 @@ func (ps *store) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
pk := p.RawString()
shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, p.Addr())]
shard.Lock()
if _, ok := shard.swarms[ih]; !ok {
@@ -279,7 +280,7 @@ func (ps *store) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
pk := p.RawString()
shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, p.Addr())]
shard.Lock()
if _, ok := shard.swarms[ih]; !ok {
@@ -312,7 +313,7 @@ func (ps *store) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
pk := p.RawString()
shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, p.Addr())]
shard.Lock()
if _, ok := shard.swarms[ih]; !ok {
@@ -343,7 +344,7 @@ func (ps *store) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error
pk := p.RawString()
shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, p.Addr())]
shard.Lock()
if _, ok := shard.swarms[ih]; !ok {
@@ -376,7 +377,7 @@ func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) erro
pk := p.RawString()
shard := ps.shards[ps.shardIndex(ih, p.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, p.Addr())]
shard.Lock()
if _, ok := shard.swarms[ih]; !ok {
@@ -404,14 +405,14 @@ func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) erro
return nil
}
func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, announcer bittorrent.Peer) (peers []bittorrent.Peer, err error) {
func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) (peers []bittorrent.Peer, err error) {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
default:
}
shard := ps.shards[ps.shardIndex(ih, announcer.IP.AddressFamily)]
shard := ps.shards[ps.shardIndex(ih, peer.Addr())]
shard.RLock()
if _, ok := shard.swarms[ih]; !ok {
@@ -445,7 +446,7 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int,
// Append leechers until we reach numWant.
if numWant > 0 {
leechers := shard.swarms[ih].leechers
announcerPK := announcer.RawString()
announcerPK := peer.RawString()
for pk := range leechers {
if pk == announcerPK {
continue
@@ -465,7 +466,7 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int,
return
}
func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) (resp bittorrent.Scrape) {
func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, peer bittorrent.Peer) (resp bittorrent.Scrape) {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -473,7 +474,7 @@ func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, addressFamily bittorrent.Ad
}
resp.InfoHash = ih
shard := ps.shards[ps.shardIndex(ih, addressFamily)]
shard := ps.shards[ps.shardIndex(ih, peer.Addr())]
shard.RLock()
swarm, ok := shard.swarms[ih]

View File

@@ -2,24 +2,24 @@
// BitTorrent tracker keeping peer data in redis with hash.
// There two categories of hash:
//
// - IPv{4,6}_{L,S}_infohash
// - CHI_{4,6}_{L,S}_infohash
// To save peers that hold the infohash, used for fast searching,
// deleting, and timeout handling
//
// - IPv{4,6}
// - CHI_{4,6}
// To save all the infohashes, used for garbage collection,
// metrics aggregation and leecher graduation
//
// Tree keys are used to record the count of swarms, seeders
// and leechers for each group (IPv4, IPv6).
//
// - IPv{4,6}_infohash_count
// - CHI_{4,6}_I_C
// To record the number of infohashes.
//
// - IPv{4,6}_S_count
// - CHI_{4,6}_S_C
// To record the number of seeders.
//
// - IPv{4,6}_L_count
// - CHI_{4,6}_L_C
// To record the number of leechers.
package redis
@@ -53,6 +53,19 @@ const (
defaultReadTimeout = time.Second * 15
defaultWriteTimeout = time.Second * 15
defaultConnectTimeout = time.Second * 15
prefixKey = "CHI_"
ih4Key = "CHI_4_I"
ih6Key = "CHI_6_I"
ih4SeederKey = "CHI_4_S_"
ih6SeederKey = "CHI_6_S_"
ih4LeecherKey = "CHI_4_L_"
ih6LeecherKey = "CHI_6_L_"
cnt4SeederKey = "CHI_4_C_S"
cnt6SeederKey = "CHI_6_C_S"
cnt4LeecherKey = "CHI_4_C_L"
cnt6LeecherKey = "CHI_6_C_L"
cnt4InfoHashKey = "CHI_4_C_I"
cnt6InfoHashKey = "CHI_6_C_I"
)
// ErrSentinelAndClusterChecked returned from initializer if both Config.Sentinel and Config.Cluster provided
@@ -264,13 +277,11 @@ func New(conf Config) (storage.Storage, error) {
ps.logFields = cfg.LogFields()
// Start a goroutine for garbage collection.
ps.wg.Add(1)
go ps.runGC(cfg.GarbageCollectionInterval, cfg.PeerLifetime)
go ps.scheduleGC(cfg.GarbageCollectionInterval, cfg.PeerLifetime)
if cfg.PrometheusReportingInterval > 0 {
// Start a goroutine for reporting statistics to Prometheus.
ps.wg.Add(1)
go ps.runProm(cfg.PrometheusReportingInterval)
go ps.schedulerProm(cfg.PrometheusReportingInterval)
} else {
log.Info("prometheus disabled because of zero reporting interval")
}
@@ -278,21 +289,32 @@ func New(conf Config) (storage.Storage, error) {
return ps, nil
}
func (ps *store) runGC(gcInterval, peerLifeTime time.Duration) {
func (ps *store) scheduleGC(gcInterval, peerLifeTime time.Duration) {
ps.wg.Add(1)
defer ps.wg.Done()
t := time.NewTimer(gcInterval)
defer t.Stop()
for {
select {
case <-ps.closed:
return
case <-time.After(gcInterval):
case <-t.C:
before := time.Now().Add(-peerLifeTime)
log.Debug("storage: purging peers with no announces since", log.Fields{"before": before})
ps.collectGarbage(before)
cutoffUnix := before.UnixNano()
start := time.Now()
ps.gc(cutoffUnix, false)
ps.gc(cutoffUnix, true)
duration := time.Since(start).Milliseconds()
log.Debug("storage: recordGCDuration", log.Fields{"timeTaken(ms)": duration})
storage.PromGCDurationMilliseconds.Observe(float64(duration))
t.Reset(gcInterval)
}
}
}
func (ps *store) runProm(reportInterval time.Duration) {
func (ps *store) schedulerProm(reportInterval time.Duration) {
ps.wg.Add(1)
defer ps.wg.Done()
t := time.NewTicker(reportInterval)
for {
@@ -317,68 +339,39 @@ type store struct {
logFields log.Fields
}
var groups = []string{bittorrent.IPv4.String(), bittorrent.IPv6.String()}
// leecherInfoHashKey generates string IPvN_L_hash
func leecherInfoHashKey(addressFamily, infoHash string) string {
return addressFamily + "_L_" + infoHash
}
// seederInfoHashKey generates string IPvN_S_hash
func seederInfoHashKey(addressFamily, infoHash string) string {
return addressFamily + "_S_" + infoHash
}
// seederInfoHashKey generates string IPvN_infohash_count
func infoHashCountKey(addressFamily string) string {
return addressFamily + "_infohash_count"
}
// seederInfoHashKey generates string IPvN_L_count
func leecherCountKey(addressFamily string) string {
return addressFamily + "_L_count"
}
// seederInfoHashKey generates string IPvN_S_count
func seederCountKey(addressFamily string) string {
return addressFamily + "_S_count"
func (ps *store) count(key string) (n uint64) {
var err error
if n, err = ps.con.Get(ps.ctx, key).Uint64(); err != nil && !errors.Is(err, redis.Nil) {
log.Error("storage: GET counter failure", log.Fields{
"key": key,
"error": err,
})
}
return
}
// populateProm aggregates metrics over all groups and then posts them to
// prometheus.
func (ps *store) populateProm() {
var numInfoHashes, numSeeders, numLeechers int64
for _, group := range groups {
if n, err := ps.con.Get(ps.ctx, infoHashCountKey(group)).Int64(); err != nil && !errors.Is(err, redis.Nil) {
log.Error("storage: GET counter failure", log.Fields{
"key": infoHashCountKey(group),
"error": err,
})
numInfoHashes, numSeeders, numLeechers := new(uint64), new(uint64), new(uint64)
fetchFn := func(v6 bool) {
var cntSeederKey, cntLeecherKey, cntInfoHashKey string
if v6 {
cntSeederKey, cntLeecherKey, cntInfoHashKey = cnt6SeederKey, cnt6LeecherKey, cnt6InfoHashKey
} else {
numInfoHashes += n
}
if n, err := ps.con.Get(ps.ctx, seederCountKey(group)).Int64(); err != nil && !errors.Is(err, redis.Nil) {
log.Error("storage: GET counter failure", log.Fields{
"key": seederCountKey(group),
"error": err,
})
} else {
numSeeders += n
}
if n, err := ps.con.Get(ps.ctx, leecherCountKey(group)).Int64(); err != nil && !errors.Is(err, redis.Nil) {
log.Error("storage: GET counter failure", log.Fields{
"key": leecherCountKey(group),
"error": err,
})
} else {
numLeechers += n
cntSeederKey, cntLeecherKey, cntInfoHashKey = cnt4SeederKey, cnt4LeecherKey, cnt4InfoHashKey
}
*numInfoHashes += ps.count(cntInfoHashKey)
*numSeeders += ps.count(cntSeederKey)
*numLeechers += ps.count(cntLeecherKey)
}
storage.PromInfoHashesCount.Set(float64(numInfoHashes))
storage.PromSeedersCount.Set(float64(numSeeders))
storage.PromLeechersCount.Set(float64(numLeechers))
fetchFn(false)
fetchFn(true)
storage.PromInfoHashesCount.Set(float64(*numInfoHashes))
storage.PromSeedersCount.Set(float64(*numSeeders))
storage.PromLeechersCount.Set(float64(*numLeechers))
}
func (ps *store) getClock() int64 {
@@ -409,154 +402,185 @@ func asNil(err error) error {
return err
}
func (ps *store) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
addressFamily := p.IP.AddressFamily.String()
func (ps *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
var ihSummaryKey, ihPeerKey, cntPeerKey, cntInfoHashKey string
log.Debug("storage: PutSeeder", log.Fields{
"InfoHash": ih,
"Peer": p,
"Peer": peer,
})
encodedSeederInfoHash := seederInfoHashKey(addressFamily, ih.RawString())
if peer.Addr().Is6() {
ihSummaryKey, ihPeerKey, cntPeerKey, cntInfoHashKey = ih6Key, ih6SeederKey, cnt6SeederKey, cnt6InfoHashKey
} else {
ihSummaryKey, ihPeerKey, cntPeerKey, cntInfoHashKey = ih4Key, ih4SeederKey, cnt4SeederKey, cnt4InfoHashKey
}
ihPeerKey += ih.RawString()
now := ps.getClock()
return ps.tx(func(tx redis.Pipeliner) (err error) {
if err = tx.HSet(ps.ctx, encodedSeederInfoHash, p.RawString(), now).Err(); err != nil {
if err = tx.HSet(ps.ctx, ihPeerKey, peer.RawString(), now).Err(); err != nil {
return
}
if err = ps.con.Incr(ps.ctx, seederCountKey(addressFamily)).Err(); err != nil {
if err = ps.con.Incr(ps.ctx, cntPeerKey).Err(); err != nil {
return
}
if err = ps.con.HSet(ps.ctx, addressFamily, encodedSeederInfoHash, now).Err(); err != nil {
var added int64
if added, err = ps.con.SAdd(ps.ctx, ihSummaryKey, ihPeerKey).Result(); err != nil {
return
}
err = ps.con.Incr(ps.ctx, infoHashCountKey(addressFamily)).Err()
if added > 0 {
err = ps.con.Incr(ps.ctx, cntInfoHashKey).Err()
}
return
})
}
func (ps *store) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
addressFamily := p.IP.AddressFamily.String()
func (ps *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
var ihPeerKey, cntPeerKey string
log.Debug("storage: DeleteSeeder", log.Fields{
"InfoHash": ih,
"Peer": p,
"Peer": peer,
})
if peer.Addr().Is6() {
ihPeerKey, cntPeerKey = ih6SeederKey, cnt6SeederKey
} else {
ihPeerKey, cntPeerKey = ih4SeederKey, cnt4SeederKey
}
ihPeerKey += ih.RawString()
encodedSeederInfoHash := seederInfoHashKey(addressFamily, ih.RawString())
deleted, err := ps.con.HDel(ps.ctx, encodedSeederInfoHash, p.RawString()).Uint64()
deleted, err := ps.con.HDel(ps.ctx, ihPeerKey, peer.RawString()).Uint64()
err = asNil(err)
if err == nil {
if deleted == 0 {
err = storage.ErrResourceDoesNotExist
} else {
err = ps.con.Decr(ps.ctx, seederCountKey(addressFamily)).Err()
err = ps.con.Decr(ps.ctx, cntPeerKey).Err()
}
}
return err
}
func (ps *store) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
addressFamily := p.IP.AddressFamily.String()
func (ps *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
var ihSummaryKey, ihPeerKey, cntPeerKey string
log.Debug("storage: PutLeecher", log.Fields{
"InfoHash": ih,
"Peer": p,
"Peer": peer,
})
if peer.Addr().Is6() {
ihSummaryKey, ihPeerKey, cntPeerKey = ih6Key, ih6LeecherKey, cnt6LeecherKey
} else {
ihSummaryKey, ihPeerKey, cntPeerKey = ih4Key, ih4LeecherKey, cnt4LeecherKey
}
ihPeerKey += ih.RawString()
// Update the peer in the swarm.
encodedLeecherInfoHash := leecherInfoHashKey(addressFamily, ih.RawString())
now := ps.getClock()
return ps.tx(func(tx redis.Pipeliner) (err error) {
if err = tx.HSet(ps.ctx, encodedLeecherInfoHash, p.RawString(), now).Err(); err != nil {
if err = tx.HSet(ps.ctx, ihPeerKey, peer.RawString(), now).Err(); err != nil {
return
}
if err = tx.HSet(ps.ctx, addressFamily, encodedLeecherInfoHash, now).Err(); err != nil {
if err = tx.Incr(ps.ctx, cntPeerKey).Err(); err != nil {
return err
}
err = tx.Incr(ps.ctx, leecherCountKey(addressFamily)).Err()
err = tx.HSet(ps.ctx, ihSummaryKey, ihPeerKey, now).Err()
return
})
}
func (ps *store) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
addressFamily := p.IP.AddressFamily.String()
func (ps *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
var ihPeerKey, cntPeerKey string
log.Debug("storage: DeleteLeecher", log.Fields{
"InfoHash": ih,
"Peer": p,
"Peer": peer,
})
encodedLeecherInfoHash := leecherInfoHashKey(addressFamily, ih.RawString())
if peer.Addr().Is6() {
ihPeerKey, cntPeerKey = ih6LeecherKey, cnt6LeecherKey
} else {
ihPeerKey, cntPeerKey = ih4LeecherKey, cnt4LeecherKey
}
ihPeerKey += ih.RawString()
deleted, err := ps.con.HDel(ps.ctx, encodedLeecherInfoHash, p.RawString()).Uint64()
deleted, err := ps.con.HDel(ps.ctx, ihPeerKey, peer.RawString()).Uint64()
err = asNil(err)
if err == nil {
if deleted == 0 {
err = storage.ErrResourceDoesNotExist
} else {
err = ps.con.Decr(ps.ctx, leecherCountKey(addressFamily)).Err()
err = ps.con.Decr(ps.ctx, cntPeerKey).Err()
}
}
return err
}
func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
addressFamily := p.IP.AddressFamily.String()
func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
var ihSummaryKey, ihSeederKey, ihLeecherKey, cntSeederKey, cntLeecherKey, cntInfoHashKey string
log.Debug("storage: GraduateLeecher", log.Fields{
"InfoHash": ih,
"Peer": p,
"Peer": peer,
})
encodedInfoHash := ih.RawString()
encodedLeecherInfoHash := leecherInfoHashKey(addressFamily, encodedInfoHash)
encodedSeederInfoHash := seederInfoHashKey(addressFamily, encodedInfoHash)
peerKey := p.RawString()
if peer.Addr().Is6() {
ihSummaryKey, ihSeederKey, cntSeederKey = ih6Key, ih6SeederKey, cnt6SeederKey
ihLeecherKey, cntLeecherKey, cntInfoHashKey = ih6LeecherKey, cnt6LeecherKey, cnt6InfoHashKey
} else {
ihSummaryKey, ihSeederKey, cntSeederKey = ih4Key, ih4SeederKey, cnt4SeederKey
ihLeecherKey, cntLeecherKey, cntInfoHashKey = ih4LeecherKey, cnt4LeecherKey, cnt4InfoHashKey
}
infoHash, peerKey := ih.RawString(), peer.RawString()
ihSeederKey, ihLeecherKey = ihSeederKey+infoHash, ihLeecherKey+infoHash
now := ps.getClock()
return ps.tx(func(tx redis.Pipeliner) error {
deleted, err := tx.HDel(ps.ctx, encodedLeecherInfoHash, peerKey).Uint64()
deleted, err := tx.HDel(ps.ctx, ihLeecherKey, peerKey).Uint64()
err = asNil(err)
if err == nil {
if deleted > 0 {
err = tx.Decr(ps.ctx, leecherCountKey(addressFamily)).Err()
err = tx.Decr(ps.ctx, cntLeecherKey).Err()
}
}
if err == nil {
err = tx.HSet(ps.ctx, encodedSeederInfoHash, peerKey, now).Err()
err = tx.HSet(ps.ctx, ihSeederKey, peerKey, now).Err()
}
if err == nil {
err = tx.Incr(ps.ctx, seederCountKey(addressFamily)).Err()
err = tx.Incr(ps.ctx, cntSeederKey).Err()
}
if err == nil {
err = tx.HSet(ps.ctx, addressFamily, encodedSeederInfoHash, now).Err()
err = tx.HSet(ps.ctx, ihSummaryKey, ihSeederKey, now).Err()
}
if err == nil {
err = tx.Incr(ps.ctx, infoHashCountKey(addressFamily)).Err()
err = tx.Incr(ps.ctx, cntInfoHashKey).Err()
}
return err
})
}
func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, announcer bittorrent.Peer) (peers []bittorrent.Peer, err error) {
addressFamily := announcer.IP.AddressFamily.String()
func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) (peers []bittorrent.Peer, err error) {
var ihSeederKey, ihLeecherKey string
log.Debug("storage: AnnouncePeers", log.Fields{
"InfoHash": ih,
"seeder": seeder,
"numWant": numWant,
"Peer": announcer,
"Peer": peer,
})
encodedInfoHash := ih.RawString()
encodedLeecherInfoHash := leecherInfoHashKey(addressFamily, encodedInfoHash)
encodedSeederInfoHash := seederInfoHashKey(addressFamily, encodedInfoHash)
if peer.Addr().Is6() {
ihSeederKey, ihLeecherKey = ih6SeederKey, cnt6LeecherKey
} else {
ihSeederKey, ihLeecherKey = ih4SeederKey, ih4LeecherKey
}
infoHash := ih.RawString()
ihSeederKey, ihLeecherKey = ihSeederKey+infoHash, ihLeecherKey+infoHash
leechers, err := ps.con.HKeys(ps.ctx, encodedLeecherInfoHash).Result()
leechers, err := ps.con.HKeys(ps.ctx, ihLeecherKey).Result()
err = asNil(err)
if err != nil {
return nil, err
}
seeders, err := ps.con.HKeys(ps.ctx, encodedSeederInfoHash).Result()
seeders, err := ps.con.HKeys(ps.ctx, ihSeederKey).Result()
err = asNil(err)
if err != nil {
return nil, err
@@ -595,7 +619,7 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int,
// Append leechers until we reach numWant.
if numWant > 0 {
announcerPK := announcer.RawString()
announcerPK := peer.RawString()
for _, peerKey := range leechers {
if peerKey != announcerPK {
if numWant == 0 {
@@ -615,28 +639,36 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int,
return
}
func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, af bittorrent.AddressFamily) (resp bittorrent.Scrape) {
func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, peer bittorrent.Peer) (resp bittorrent.Scrape) {
var ihSeederKey, ihLeecherKey string
log.Debug("storage: ScrapeSwarm", log.Fields{
"InfoHash": ih,
"Peer": peer,
})
resp.InfoHash = ih
addressFamily := af.String()
encodedInfoHash := ih.RawString()
encodedLeecherInfoHash := leecherInfoHashKey(addressFamily, encodedInfoHash)
encodedSeederInfoHash := seederInfoHashKey(addressFamily, encodedInfoHash)
if peer.Addr().Is6() {
ihSeederKey, ihLeecherKey = ih6SeederKey, cnt6LeecherKey
} else {
ihSeederKey, ihLeecherKey = ih4SeederKey, ih4LeecherKey
}
infoHash := ih.RawString()
ihSeederKey, ihLeecherKey = ihSeederKey+infoHash, ihLeecherKey+infoHash
leechersLen, err := ps.con.HLen(ps.ctx, encodedLeecherInfoHash).Result()
leechersLen, err := ps.con.HLen(ps.ctx, ihLeecherKey).Result()
err = asNil(err)
if err != nil {
log.Error("storage: Redis HLEN failure", log.Fields{
"Hkey": encodedLeecherInfoHash,
"Hkey": ihLeecherKey,
"error": err,
})
return
}
seedersLen, err := ps.con.HLen(ps.ctx, encodedSeederInfoHash).Result()
seedersLen, err := ps.con.HLen(ps.ctx, ihSeederKey).Result()
err = asNil(err)
if err != nil {
log.Error("storage: Redis HLEN failure", log.Fields{
"Hkey": encodedSeederInfoHash,
"Hkey": ihSeederKey,
"error": err,
})
return
@@ -649,11 +681,11 @@ func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash, af bittorrent.AddressFamily
}
func (ps *store) Put(ctx string, value storage.Entry) error {
return ps.con.HSet(ps.ctx, ctx, value.Key, value.Value).Err()
return ps.con.HSet(ps.ctx, prefixKey+ctx, value.Key, value.Value).Err()
}
func (ps *store) Contains(ctx string, key string) (bool, error) {
exist, err := ps.con.HExists(ps.ctx, ctx, key).Result()
exist, err := ps.con.HExists(ps.ctx, prefixKey+ctx, key).Result()
return exist, asNil(err)
}
@@ -665,12 +697,12 @@ func (ps *store) BulkPut(ctx string, pairs ...storage.Entry) (err error) {
for _, p := range pairs {
args = append(args, p.Key, p.Value)
}
err = ps.con.HSet(ps.ctx, ctx, args...).Err()
err = ps.con.HSet(ps.ctx, prefixKey+ctx, args...).Err()
if err != nil {
if strings.Contains(err.Error(), argNumErrorMsg) {
log.Warn("This REDIS version/implementation does not support variadic arguments for HSET")
for _, p := range pairs {
if err = ps.con.HSet(ps.ctx, ctx, p.Key, p.Value).Err(); err != nil {
if err = ps.con.HSet(ps.ctx, prefixKey+ctx, p.Key, p.Value).Err(); err != nil {
break
}
}
@@ -681,7 +713,7 @@ func (ps *store) BulkPut(ctx string, pairs ...storage.Entry) (err error) {
}
func (ps *store) Load(ctx string, key string) (v any, err error) {
v, err = ps.con.HGet(ps.ctx, ctx, key).Result()
v, err = ps.con.HGet(ps.ctx, prefixKey+ctx, key).Result()
if err != nil && errors.Is(err, redis.Nil) {
v, err = nil, nil
}
@@ -690,12 +722,12 @@ func (ps *store) Load(ctx string, key string) (v any, err error) {
func (ps *store) Delete(ctx string, keys ...string) (err error) {
if len(keys) > 0 {
err = asNil(ps.con.HDel(ps.ctx, ctx, keys...).Err())
err = asNil(ps.con.HDel(ps.ctx, prefixKey+ctx, keys...).Err())
if err != nil {
if strings.Contains(err.Error(), argNumErrorMsg) {
log.Warn("This REDIS version/implementation does not support variadic arguments for HDEL")
for _, k := range keys {
if err = asNil(ps.con.HDel(ps.ctx, ctx, k).Err()); err != nil {
if err = asNil(ps.con.HDel(ps.ctx, prefixKey+ctx, k).Err()); err != nil {
break
}
}
@@ -705,7 +737,7 @@ func (ps *store) Delete(ctx string, keys ...string) (err error) {
return
}
// collectGarbage deletes all Peers from the Storage which are older than the
// gc deletes all Peers from the Storage which are older than the
// cutoff time.
//
// This function must be able to execute while other methods on this interface
@@ -717,11 +749,11 @@ func (ps *store) Delete(ctx string, keys ...string) (err error) {
// - The Put(Seeder|Leecher) and GraduateLeecher methods only ever add infohash
// keys to addressFamily hashes and increment the infohash counter.
// - The only method that deletes from the addressFamily hashes is
// collectGarbage, which also decrements the counters. That means that,
// gc, which also decrements the counters. That means that,
// even if a Delete(Seeder|Leecher) call removes the last peer from a swarm,
// the infohash counter is not changed and the infohash is left in the
// addressFamily hash until it will be cleaned up by collectGarbage.
// - collectGarbage must run regularly.
// addressFamily hash until it will be cleaned up by gc.
// - gc must run regularly.
// - A WATCH ... MULTI ... EXEC block fails, if between the WATCH and the 'EXEC'
// any of the watched keys have changed. The location of the 'MULTI' doesn't
// matter.
@@ -740,120 +772,124 @@ func (ps *store) Delete(ctx string, keys ...string) (err error) {
// not empty and start no transaction.
// - If the change happens after the HLEN, we will attempt a transaction and it
// will fail. This is okay, the swarm is not empty, we will try cleaning it up
// next time collectGarbage runs.
// next time gc runs.
// 4. (1,0): Again, two ways:
// - If the change happens before the HLEN, we will see an empty swarm. This
// situation happens if a call to Delete(Seeder|Leecher) removed the last
// peer asynchronously. We will attempt a transaction, but the transaction
// will fail. This is okay, the infohash key will remain in the addressFamily
// hash, we will attempt to clean it up the next time 'collectGarbage` runs.
// hash, we will attempt to clean it up the next time 'gc` runs.
// - If the change happens after the HLEN, we will not even attempt to make the
// transaction. The infohash key will remain in the addressFamil hash and
// we'll attempt to clean it up the next time collectGarbage runs.
func (ps *store) collectGarbage(cutoff time.Time) {
cutoffUnix := cutoff.UnixNano()
start := time.Now()
var err error
for _, group := range groups {
// list all infoHashes in the group
var infoHashes []string
infoHashes, err = ps.con.HKeys(ps.ctx, group).Result()
err = asNil(err)
if err == nil {
for _, infoHash := range infoHashes {
isSeeder := len(infoHash) > 5 && infoHash[5:6] == "S"
// list all (peer, timeout) pairs for the ih
peerList, err := ps.con.HGetAll(ps.ctx, infoHash).Result()
err = asNil(err)
if err == nil {
var removedPeerCount int64
for peerKey, timeStamp := range peerList {
var peer bittorrent.Peer
if peer, err = bittorrent.NewPeer(peerKey); err == nil {
if mtime, err := strconv.ParseInt(timeStamp, 10, 64); err == nil {
if mtime <= cutoffUnix {
log.Debug("storage: deleting peer", log.Fields{
"Peer": peer,
})
var count int64
count, err = ps.con.HDel(ps.ctx, infoHash, peerKey).Result()
err = asNil(err)
if err == nil {
removedPeerCount += count
}
// we'll attempt to clean it up the next time gc runs.
func (ps *store) gc(cutoffUnix int64, v6 bool) {
// list all infoHashKeys in the group
var ihSummaryKey, ihSeederKey, ihLeecherKey, cntSeederKey, cntLeecherKey, cntInfoHashKey string
if v6 {
cntSeederKey, cntLeecherKey, cntInfoHashKey = cnt6SeederKey, cnt6LeecherKey, cnt6InfoHashKey
ihSummaryKey, ihSeederKey, ihLeecherKey = ih6Key, ih6SeederKey, ih6LeecherKey
} else {
cntSeederKey, cntLeecherKey, cntInfoHashKey = cnt4SeederKey, cnt4LeecherKey, cnt4InfoHashKey
ihSummaryKey, ihSeederKey, ihLeecherKey = ih4Key, ih4SeederKey, ih4LeecherKey
}
infoHashKeys, err := ps.con.SMembers(ps.ctx, ihSummaryKey).Result()
err = asNil(err)
if err == nil {
for _, infoHashKey := range infoHashKeys {
var cntKey string
var seeder bool
if seeder = strings.HasPrefix(infoHashKey, ihSeederKey); seeder {
cntKey = cntSeederKey
} else if strings.HasPrefix(infoHashKey, ihLeecherKey) {
cntKey = cntLeecherKey
} else {
log.Warn("storage: Redis: unexpected record found in info hash set", log.Fields{
"hashSet": ihSummaryKey,
"infoHashKey": infoHashKey,
})
continue
}
// list all (peer, timeout) pairs for the ih
peerList, err := ps.con.HGetAll(ps.ctx, infoHashKey).Result()
err = asNil(err)
if err == nil {
var removedPeerCount int64
for peerKey, timeStamp := range peerList {
var peer bittorrent.Peer
if peer, err = bittorrent.NewPeer(peerKey); err == nil {
if mtime, err := strconv.ParseInt(timeStamp, 10, 64); err == nil {
if mtime <= cutoffUnix {
log.Debug("storage: Redis: deleting peer", log.Fields{
"Peer": peer,
})
var count int64
count, err = ps.con.HDel(ps.ctx, infoHashKey, peerKey).Result()
err = asNil(err)
if err == nil {
removedPeerCount += count
}
}
}
if err != nil {
log.Error("storage: Redis: unable to delete info hash peer", log.Fields{
"group": group,
"infoHash": infoHash,
"peer": peer,
"key": peerKey,
"error": err,
})
}
}
// DECR seeder/leecher counter
if removedPeerCount > 0 {
var decrCounter string
if isSeeder {
decrCounter = seederCountKey(group)
} else {
decrCounter = leecherCountKey(group)
}
if err := ps.con.DecrBy(ps.ctx, decrCounter, removedPeerCount).Err(); err != nil {
log.Error("storage: Redis: unable to decrement seeder/leecher peer count", log.Fields{
"group": group,
"infoHash": infoHash,
"key": decrCounter,
"error": err,
})
}
}
// use WATCH to avoid race condition
// https://redis.io/topics/transactions
err = asNil(ps.con.Watch(ps.ctx, func(tx *redis.Tx) (err error) {
var infoHashCount int64
infoHashCount, err = ps.con.HLen(ps.ctx, infoHash).Result()
err = asNil(err)
if err == nil && infoHashCount == 0 {
// Empty hashes are not shown among existing keys,
// in other words, it's removed automatically after `HDEL` the last field.
// _, err := ps.con.Del(ps.ctx, infoHash)
var deletedCount int64
deletedCount, err = ps.con.HDel(ps.ctx, group, infoHash).Result()
err = asNil(err)
if err == nil && isSeeder && deletedCount > 0 {
err = ps.con.Decr(ps.ctx, infoHashCountKey(group)).Err()
}
}
return err
}, infoHash))
if err != nil {
log.Error("storage: Redis: unable to clean info hash records", log.Fields{
"group": group,
"infoHash": infoHash,
"error": err,
log.Error("storage: Redis: unable to delete info hash peer", log.Fields{
"hashSet": ihSummaryKey,
"infoHashKey": infoHashKey,
"peer": peer,
"key": peerKey,
"error": err,
})
}
} else {
log.Error("storage: Redis: unable to fetch info hash peers", log.Fields{
"group": group,
"infoHash": infoHash,
"error": err,
}
// DECR seeder/leecher counter
if removedPeerCount > 0 {
if err := ps.con.DecrBy(ps.ctx, cntKey, removedPeerCount).Err(); err != nil {
log.Error("storage: Redis: unable to decrement seeder/leecher peer count", log.Fields{
"hashSet": ihSummaryKey,
"infoHashKey": infoHashKey,
"key": cntKey,
"error": err,
})
}
}
// use WATCH to avoid race condition
// https://redis.io/topics/transactions
err = asNil(ps.con.Watch(ps.ctx, func(tx *redis.Tx) (err error) {
var infoHashCount int64
infoHashCount, err = ps.con.HLen(ps.ctx, infoHashKey).Result()
err = asNil(err)
if err == nil && infoHashCount == 0 {
// Empty hashes are not shown among existing keys,
// in other words, it's removed automatically after `HDEL` the last field.
// _, err := ps.con.Del(ps.ctx, infoHashKey)
var deletedCount int64
deletedCount, err = ps.con.SRem(ps.ctx, ihSummaryKey, infoHashKey).Result()
err = asNil(err)
if err == nil && seeder && deletedCount > 0 {
err = ps.con.Decr(ps.ctx, cntInfoHashKey).Err()
}
}
return err
}, infoHashKey))
if err != nil {
log.Error("storage: Redis: unable to clean info hash records", log.Fields{
"hashSet": ihSummaryKey,
"infoHashKey": infoHashKey,
"error": err,
})
}
} else {
log.Error("storage: Redis: unable to fetch info hash peers", log.Fields{
"hashSet": ihSummaryKey,
"infoHashKey": infoHashKey,
"error": err,
})
}
} else {
log.Error("storage: Redis: unable to fetch info hashes", log.Fields{"group": group, "error": err})
}
} else {
log.Error("storage: Redis: unable to fetch info hash set", log.Fields{"hashSet": ihSummaryKey, "error": err})
}
duration := time.Since(start).Milliseconds()
log.Debug("storage: recordGCDuration", log.Fields{"timeTaken(ms)": duration})
storage.PromGCDurationMilliseconds.Observe(float64(duration))
}
func (ps *store) Stop() stop.Result {

View File

@@ -55,33 +55,33 @@ var ErrDriverDoesNotExist = errors.New("peer store driver with that name does no
type Storage interface {
// PutSeeder adds a Seeder to the Swarm identified by the provided
// InfoHash.
PutSeeder(infoHash bittorrent.InfoHash, p bittorrent.Peer) error
PutSeeder(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error
// DeleteSeeder removes a Seeder from the Swarm identified by the
// provided InfoHash.
//
// If the Swarm or Peer does not exist, this function returns
// ErrResourceDoesNotExist.
DeleteSeeder(infoHash bittorrent.InfoHash, p bittorrent.Peer) error
DeleteSeeder(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error
// PutLeecher adds a Leecher to the Swarm identified by the provided
// InfoHash.
// If the Swarm does not exist already, it is created.
PutLeecher(infoHash bittorrent.InfoHash, p bittorrent.Peer) error
PutLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error
// DeleteLeecher removes a Leecher from the Swarm identified by the
// provided InfoHash.
//
// If the Swarm or Peer does not exist, this function returns
// ErrResourceDoesNotExist.
DeleteLeecher(infoHash bittorrent.InfoHash, p bittorrent.Peer) error
DeleteLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error
// GraduateLeecher promotes a Leecher to a Seeder in the Swarm
// identified by the provided InfoHash.
//
// If the given Peer is not present as a Leecher or the swarm does not exist
// already, the Peer is added as a Seeder and no error is returned.
GraduateLeecher(infoHash bittorrent.InfoHash, p bittorrent.Peer) error
GraduateLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error
// AnnouncePeers is a best effort attempt to return Peers from the Swarm
// identified by the provided InfoHash.
@@ -98,7 +98,7 @@ type Storage interface {
// leechers
//
// Returns ErrResourceDoesNotExist if the provided InfoHash is not tracked.
AnnouncePeers(infoHash bittorrent.InfoHash, seeder bool, numWant int, p bittorrent.Peer) (peers []bittorrent.Peer, err error)
AnnouncePeers(infoHash bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) (peers []bittorrent.Peer, err error)
// ScrapeSwarm returns information required to answer a Scrape request
// about a Swarm identified by the given InfoHash.
@@ -108,7 +108,7 @@ type Storage interface {
// filling the Snatches field is optional.
//
// If the Swarm does not exist, an empty Scrape and no error is returned.
ScrapeSwarm(infoHash bittorrent.InfoHash, addressFamily bittorrent.AddressFamily) bittorrent.Scrape
ScrapeSwarm(infoHash bittorrent.InfoHash, peer bittorrent.Peer) bittorrent.Scrape
// Put used to place arbitrary k-v data with specified context
// into storage. ctx parameter used to group data
@@ -117,7 +117,7 @@ type Storage interface {
// BulkPut used to place array of k-v data in specified context.
// Useful when several data entries should be added in single transaction/connection
BulkPut(ctx string, pairs ...Entry) error
BulkPut(ctx string, values ...Entry) error
// Contains checks if any data in specified context exist
Contains(ctx string, key string) (bool, error)

View File

@@ -2,13 +2,13 @@ package test
import (
"math/rand"
"net"
"net/netip"
"runtime"
"sync/atomic"
"testing"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/randseed"
_ "github.com/sot-tech/mochi/pkg/randseed"
"github.com/sot-tech/mochi/storage"
)
@@ -28,23 +28,25 @@ func generateInfoHashes() (a [1000]bittorrent.InfoHash) {
}
func generatePeers() (a [1000]bittorrent.Peer) {
r := rand.New(rand.NewSource(randseed.GenSeed()))
for i := range a {
ip := make([]byte, 4)
n, err := r.Read(ip)
n, err := rand.Read(ip)
if err != nil || n != 4 {
panic("unable to create random bytes")
}
id := [bittorrent.PeerIDLen]byte{}
n, err = r.Read(id[:])
n, err = rand.Read(id[:])
if err != nil || n != bittorrent.InfoHashV1Len {
panic("unable to create random bytes")
}
port := uint16(r.Uint32())
addr, ok := netip.AddrFromSlice(ip)
if !ok {
panic("unable to create ip from random bytes")
}
port := uint16(rand.Uint32())
a[i] = bittorrent.Peer{
ID: id,
IP: bittorrent.IP{IP: net.IP(ip), AddressFamily: bittorrent.IPv4},
Port: port,
ID: id,
AddrPort: netip.AddrPortFrom(addr, port),
}
}
@@ -442,7 +444,7 @@ func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) {
// ScrapeSwarm can run in parallel.
func (bh *benchHolder) ScrapeSwarm(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.Storage, bd *benchData) error {
ps.ScrapeSwarm(bd.infohashes[0], bittorrent.IPv4)
ps.ScrapeSwarm(bd.infohashes[0], bd.peers[0])
return nil
})
}
@@ -452,7 +454,7 @@ func (bh *benchHolder) ScrapeSwarm(b *testing.B) {
// ScrapeSwarm1kInfoHash can run in parallel.
func (bh *benchHolder) ScrapeSwarm1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.Storage, bd *benchData) error {
ps.ScrapeSwarm(bd.infohashes[i%1000], bittorrent.IPv4)
ps.ScrapeSwarm(bd.infohashes[i%1000], bd.peers[0])
return nil
})
}

View File

@@ -34,7 +34,7 @@ func (th *testHolder) DeleteSeeder(t *testing.T) {
func (th *testHolder) PutLeecher(t *testing.T) {
for _, c := range testData {
peer := v4Peer
if c.peer.IP.AddressFamily == bittorrent.IPv6 {
if c.peer.Addr().Is6() {
peer = v6Peer
}
err := th.st.PutLeecher(c.ih, peer)
@@ -52,7 +52,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) {
func (th *testHolder) AnnouncePeers(t *testing.T) {
for _, c := range testData {
peer := v4Peer
if c.peer.IP.AddressFamily == bittorrent.IPv6 {
if c.peer.Addr().Is6() {
peer = v6Peer
}
_, err := th.st.AnnouncePeers(c.ih, false, 50, peer)
@@ -62,7 +62,7 @@ func (th *testHolder) AnnouncePeers(t *testing.T) {
func (th *testHolder) ScrapeSwarm(t *testing.T) {
for _, c := range testData {
scrape := th.st.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
scrape := th.st.ScrapeSwarm(c.ih, c.peer)
require.Equal(t, uint32(0), scrape.Complete)
require.Equal(t, uint32(0), scrape.Incomplete)
require.Equal(t, uint32(0), scrape.Snatches)
@@ -72,7 +72,7 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) {
func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) {
for _, c := range testData {
peer := v4Peer
if c.peer.IP.AddressFamily == bittorrent.IPv6 {
if c.peer.Addr().Is6() {
peer = v6Peer
}
err := th.st.PutLeecher(c.ih, c.peer)
@@ -87,7 +87,7 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) {
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
scrape := th.st.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
scrape := th.st.ScrapeSwarm(c.ih, c.peer)
require.Equal(t, uint32(2), scrape.Incomplete)
require.Equal(t, uint32(0), scrape.Complete)
@@ -103,7 +103,7 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) {
func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) {
for _, c := range testData {
peer := v4Peer
if c.peer.IP.AddressFamily == bittorrent.IPv6 {
if c.peer.Addr().Is6() {
peer = v6Peer
}
err := th.st.PutSeeder(c.ih, c.peer)
@@ -114,7 +114,7 @@ func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) {
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
scrape := th.st.ScrapeSwarm(c.ih, c.peer.IP.AddressFamily)
scrape := th.st.ScrapeSwarm(c.ih, c.peer)
require.Equal(t, uint32(1), scrape.Incomplete)
require.Equal(t, uint32(1), scrape.Complete)
@@ -130,7 +130,7 @@ func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) {
func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) {
for _, c := range testData {
peer := v4Peer
if c.peer.IP.AddressFamily == bittorrent.IPv6 {
if c.peer.Addr().Is6() {
peer = v6Peer
}
err := th.st.PutLeecher(c.ih, c.peer)

View File

@@ -1,33 +1,36 @@
package test
import (
"net"
"net/netip"
"github.com/sot-tech/mochi/bittorrent"
)
var (
testIh1, testIh2 bittorrent.InfoHash
testPeerID bittorrent.PeerID
testData []hashPeer
v4Peer, v6Peer bittorrent.Peer
testIh1, testIh2 bittorrent.InfoHash
testPeerID0, testPeerID1, testPeerID2, testPeerID3 bittorrent.PeerID
testData []hashPeer
v4Peer, v6Peer bittorrent.Peer
)
func init() {
testIh1, _ = bittorrent.NewInfoHash("00000000000000000001")
testIh2, _ = bittorrent.NewInfoHash("00000000000000000002")
testPeerID, _ = bittorrent.NewPeerID([]byte("00000000000000000001"))
testPeerID0, _ = bittorrent.NewPeerID([]byte("00000000000000000001"))
testPeerID1, _ = bittorrent.NewPeerID([]byte("00000000000000000002"))
testPeerID2, _ = bittorrent.NewPeerID([]byte("99999999999999999994"))
testPeerID3, _ = bittorrent.NewPeerID([]byte("99999999999999999996"))
testData = []hashPeer{
{
testIh1,
bittorrent.Peer{ID: testPeerID, Port: 1, IP: bittorrent.IP{IP: net.ParseIP("1.1.1.1").To4(), AddressFamily: bittorrent.IPv4}},
bittorrent.Peer{ID: testPeerID0, AddrPort: netip.MustParseAddrPort("1.1.1.1:1")},
},
{
testIh2,
bittorrent.Peer{ID: testPeerID, Port: 2, IP: bittorrent.IP{IP: net.ParseIP("abab::0001"), AddressFamily: bittorrent.IPv6}},
bittorrent.Peer{ID: testPeerID1, AddrPort: netip.MustParseAddrPort("[abab::0001]:2")},
},
}
v4Peer = bittorrent.Peer{ID: testPeerID, IP: bittorrent.IP{IP: net.ParseIP("99.99.99.99").To4(), AddressFamily: bittorrent.IPv4}, Port: 9994}
v6Peer = bittorrent.Peer{ID: testPeerID, IP: bittorrent.IP{IP: net.ParseIP("fc00::0001"), AddressFamily: bittorrent.IPv6}, Port: 9996}
v4Peer = bittorrent.Peer{ID: testPeerID2, AddrPort: netip.MustParseAddrPort("99.99.99.99:9994")}
v6Peer = bittorrent.Peer{ID: testPeerID3, AddrPort: netip.MustParseAddrPort("[fc00::0001]:9996")}
}