diff --git a/bittorrent/bittorrent.go b/bittorrent/bittorrent.go index 50d2ce4..689427e 100644 --- a/bittorrent/bittorrent.go +++ b/bittorrent/bittorrent.go @@ -11,7 +11,6 @@ import ( "fmt" "net" "net/netip" - "time" "github.com/pkg/errors" @@ -119,98 +118,6 @@ func (i InfoHash) RawString() string { return string(i) } -// AnnounceRequest represents the parsed parameters from an announce request. -type AnnounceRequest struct { - Event Event - InfoHash InfoHash - Compact bool - EventProvided bool - NumWantProvided bool - IPProvided bool - NumWant uint32 - Left uint64 - Downloaded uint64 - Uploaded uint64 - - Peer - Params -} - -// LogFields renders the current response as a set of log fields. -func (r AnnounceRequest) LogFields() log.Fields { - return log.Fields{ - "event": r.Event, - "infoHash": r.InfoHash, - "compact": r.Compact, - "eventProvided": r.EventProvided, - "numWantProvided": r.NumWantProvided, - "ipProvided": r.IPProvided, - "numWant": r.NumWant, - "left": r.Left, - "downloaded": r.Downloaded, - "uploaded": r.Uploaded, - "peer": r.Peer, - "params": r.Params, - } -} - -// AnnounceResponse represents the parameters used to create an announce -// response. -type AnnounceResponse struct { - Compact bool - Complete uint32 - Incomplete uint32 - Interval time.Duration - MinInterval time.Duration - IPv4Peers []Peer - IPv6Peers []Peer -} - -// LogFields renders the current response as a set of log fields. -func (r AnnounceResponse) LogFields() log.Fields { - return log.Fields{ - "compact": r.Compact, - "complete": r.Complete, - "interval": r.Interval, - "minInterval": r.MinInterval, - "ipv4Peers": r.IPv4Peers, - "ipv6Peers": r.IPv6Peers, - } -} - -// ScrapeRequest represents the parsed parameters from a scrape request. -type ScrapeRequest struct { - // netip.Addr not used in internal logic, - // but MAY be used in middleware (per-ip block etc.) 
- netip.Addr - InfoHashes []InfoHash - Params Params -} - -// LogFields renders the current response as a set of log fields. -func (r ScrapeRequest) LogFields() log.Fields { - return log.Fields{ - "ip": r.Addr, - "infoHashes": r.InfoHashes, - "params": r.Params, - } -} - -// ScrapeResponse represents the parameters used to create a scrape response. -// -// The Scrapes must be in the same order as the InfoHashes in the corresponding -// ScrapeRequest. -type ScrapeResponse struct { - Files []Scrape -} - -// LogFields renders the current response as a set of Logrus fields. -func (sr ScrapeResponse) LogFields() log.Fields { - return log.Fields{ - "files": sr.Files, - } -} - // Scrape represents the state of a swarm that is returned in a scrape response. type Scrape struct { InfoHash InfoHash diff --git a/bittorrent/request.go b/bittorrent/request.go new file mode 100644 index 0000000..a095109 --- /dev/null +++ b/bittorrent/request.go @@ -0,0 +1,200 @@ +package bittorrent + +import ( + "fmt" + "net/netip" + "sort" + "time" + + "github.com/sot-tech/mochi/pkg/log" +) + +// RequestAddress wrapper for netip.Addr with Provided flag. +// Used in RequestAddresses to determine addresses priority +type RequestAddress struct { + netip.Addr + Provided bool +} + +// Validate checks if netip.Addr is valid and not unspecified (0.0.0.0) +func (a RequestAddress) Validate() bool { + return a.IsValid() && !a.IsUnspecified() +} + +func (a RequestAddress) String() string { + var p string + if a.Provided { + p = "(provided)" + } else { + p = "(detected)" + } + return fmt.Sprint(a.Addr.String(), p) +} + +// RequestAddresses is an array of RequestAddress used mainly for +// sort.Interface implementation. 
+// Frontends may determine peer's address from connections info +// or from provided values or combine these addresses to fetch maximum +// connection information about peer +type RequestAddresses []RequestAddress + +func (aa RequestAddresses) Len() int { + return len(aa) +} + +// Less returns true only if i-th RequestAddress is marked as +// RequestAddress.Provided and j-th is not (provided address has +// higher priority) +func (aa RequestAddresses) Less(i, j int) bool { + return aa[i].Provided && !aa[j].Provided +} + +func (aa RequestAddresses) Swap(i, j int) { + aa[i], aa[j] = aa[j], aa[i] +} + +// Add checks if provided RequestAddress is valid and adds unmapped +// netip.Addr to array +func (aa *RequestAddresses) Add(a RequestAddress) { + if a.Validate() { + a.Addr = a.Unmap() + *aa = append(*aa, a) + } +} + +// Validate checks if array is not empty and every RequestAddress is valid, +// then sorts addresses with Sort +func (aa RequestAddresses) Validate() bool { + if len(aa) == 0 { + return false + } + for _, a := range aa { + if !a.Validate() { + return false + } + } + if len(aa) > 1 { + sort.Sort(aa) + } + return true +} + +// GetFirst returns first address from array +// or empty netip.Addr if array is empty +func (aa RequestAddresses) GetFirst() netip.Addr { + var a netip.Addr + if len(aa) > 0 { + a = aa[0].Addr + } + return a +} + +// RequestPeer is bundle of peer ID, provided or +// determined addresses and net port +type RequestPeer struct { + ID PeerID + Port uint16 + RequestAddresses +} + +// Peers constructs array of Peer-s with the same ID and Port +// for every RequestAddress array. +func (rp RequestPeer) Peers() (peers []Peer) { + for _, a := range rp.RequestAddresses { + peers = append(peers, Peer{ + ID: rp.ID, + AddrPort: netip.AddrPortFrom(a.Addr, rp.Port), + }) + } + return +} + +// AnnounceRequest represents the parsed parameters from an announce request. 
+type AnnounceRequest struct { + Event Event + InfoHash InfoHash + Compact bool + EventProvided bool + NumWantProvided bool + NumWant uint32 + Left uint64 + Downloaded uint64 + Uploaded uint64 + + RequestPeer + Params +} + +// LogFields renders the current response as a set of log fields. +func (r AnnounceRequest) LogFields() log.Fields { + return log.Fields{ + "event": r.Event, + "infoHash": r.InfoHash, + "compact": r.Compact, + "eventProvided": r.EventProvided, + "numWantProvided": r.NumWantProvided, + "numWant": r.NumWant, + "left": r.Left, + "downloaded": r.Downloaded, + "uploaded": r.Uploaded, + "peers": r.RequestPeer, + "params": r.Params, + } +} + +// AnnounceResponse represents the parameters used to create an announce +// response. +type AnnounceResponse struct { + Compact bool + Complete uint32 + Incomplete uint32 + Interval time.Duration + MinInterval time.Duration + IPv4Peers []Peer + IPv6Peers []Peer +} + +// LogFields renders the current response as a set of log fields. +func (r AnnounceResponse) LogFields() log.Fields { + return log.Fields{ + "compact": r.Compact, + "complete": r.Complete, + "interval": r.Interval, + "minInterval": r.MinInterval, + "ipv4Peers": r.IPv4Peers, + "ipv6Peers": r.IPv6Peers, + } +} + +// ScrapeRequest represents the parsed parameters from a scrape request. +type ScrapeRequest struct { + // RequestAddresses not used in internal logic, + // but MAY be used in middleware (per-ip block etc.) + RequestAddresses + InfoHashes []InfoHash + Params Params +} + +// LogFields renders the current response as a set of log fields. +func (r ScrapeRequest) LogFields() log.Fields { + return log.Fields{ + "ip": r.RequestAddresses, + "infoHashes": r.InfoHashes, + "params": r.Params, + } +} + +// ScrapeResponse represents the parameters used to create a scrape response. +// +// The Scrapes must be in the same order as the InfoHashes in the corresponding +// ScrapeRequest. 
+type ScrapeResponse struct { + Files []Scrape +} + +// LogFields renders the current response as a set of Logrus fields. +func (sr ScrapeResponse) LogFields() log.Fields { + return log.Fields{ + "files": sr.Files, + } +} diff --git a/bittorrent/sanitize.go b/bittorrent/sanitize.go index 63ed49b..2d8167e 100644 --- a/bittorrent/sanitize.go +++ b/bittorrent/sanitize.go @@ -1,8 +1,6 @@ package bittorrent import ( - "net/netip" - "github.com/sot-tech/mochi/pkg/log" ) @@ -17,23 +15,23 @@ var ( // SanitizeAnnounce enforces a max and default NumWant and coerces the peer's // IP address into the proper format. func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) error { - if r.Port() == 0 { + if r.Port == 0 { return ErrInvalidPort } + if !r.Validate() { + return ErrInvalidIP + } + if !r.NumWantProvided { r.NumWant = defaultNumWant } else if r.NumWant > maxNumWant { r.NumWant = maxNumWant } - r.AddrPort = netip.AddrPortFrom(r.Addr(), r.Port()) - if !r.Addr().IsValid() || r.Addr().IsUnspecified() { - return ErrInvalidIP - } - log.Debug("sanitized announce", r, log.Fields{ - "ipPort": r.AddrPort, + "port": r.Port, + "addresses": r.RequestAddresses, "maxNumWant": maxNumWant, "defaultNumWant": defaultNumWant, }) @@ -41,19 +39,18 @@ func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) err } // SanitizeScrape enforces a max number of infohashes for a single scrape -// request. +// request and checks if addresses are valid. 
func SanitizeScrape(r *ScrapeRequest, maxScrapeInfoHashes uint32) error { if len(r.InfoHashes) > int(maxScrapeInfoHashes) { r.InfoHashes = r.InfoHashes[:maxScrapeInfoHashes] } - r.Addr = r.Addr.Unmap() - if !r.Addr.IsValid() || r.Addr.IsUnspecified() { + if !r.Validate() { return ErrInvalidIP } log.Debug("sanitized scrape", r, log.Fields{ - "ip": r.Addr, + "addresses": r.RequestAddresses, "maxScrapeInfoHashes": maxScrapeInfoHashes, }) return nil diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index 69df3bc..656f428 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -106,32 +106,7 @@ func (cfg Config) Validate() Config { } } - if cfg.MaxNumWant <= 0 { - validcfg.MaxNumWant = defaultMaxNumWant - log.Warn("falling back to default configuration", log.Fields{ - "name": "http.MaxNumWant", - "provided": cfg.MaxNumWant, - "default": validcfg.MaxNumWant, - }) - } - - if cfg.DefaultNumWant <= 0 { - validcfg.DefaultNumWant = defaultDefaultNumWant - log.Warn("falling back to default configuration", log.Fields{ - "name": "http.DefaultNumWant", - "provided": cfg.DefaultNumWant, - "default": validcfg.DefaultNumWant, - }) - } - - if cfg.MaxScrapeInfoHashes <= 0 { - validcfg.MaxScrapeInfoHashes = defaultMaxScrapeInfoHashes - log.Warn("falling back to default configuration", log.Fields{ - "name": "http.MaxScrapeInfoHashes", - "provided": cfg.MaxScrapeInfoHashes, - "default": validcfg.MaxScrapeInfoHashes, - }) - } + validcfg.ParseOptions.ParseOptions = cfg.ParseOptions.ParseOptions.Validate() return validcfg } @@ -286,6 +261,7 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http var err error var start time.Time var addr netip.Addr + var req *bittorrent.AnnounceRequest if f.EnableRequestTiming && metrics.Enabled() { start = time.Now() defer func() { @@ -293,12 +269,12 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http }() } - req, err := ParseAnnounce(r, f.ParseOptions) + req, 
err = ParseAnnounce(r, f.ParseOptions) if err != nil { WriteError(w, err) return } - addr = req.Addr() + addr = req.GetFirst() ctx := injectRouteParamsToContext(context.Background(), ps) ctx, resp, err := f.logic.HandleAnnounce(ctx, req) @@ -334,7 +310,7 @@ func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httpro WriteError(w, err) return } - addr = req.Addr + addr = req.GetFirst() ctx := injectRouteParamsToContext(context.Background(), ps) ctx, resp, err := f.logic.HandleScrape(ctx, req) diff --git a/frontend/http/parser.go b/frontend/http/parser.go index fd8b1ba..16e5f26 100644 --- a/frontend/http/parser.go +++ b/frontend/http/parser.go @@ -6,6 +6,7 @@ import ( "net/netip" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/frontend" ) // ParseOptions is the configuration used to parse an Announce Request. @@ -14,20 +15,10 @@ import ( // If RealIPHeader is not empty string, the value of the first HTTP Header with // that name will be used. type ParseOptions struct { - AllowIPSpoofing bool `cfg:"allow_ip_spoofing"` - RealIPHeader string `cfg:"real_ip_header"` - MaxNumWant uint32 `cfg:"max_numwant"` - DefaultNumWant uint32 `cfg:"default_numwant"` - MaxScrapeInfoHashes uint32 `cfg:"max_scrape_infohashes"` + frontend.ParseOptions + RealIPHeader string `cfg:"real_ip_header"` } -// Default parser config constants. -const ( - defaultMaxNumWant = 100 - defaultDefaultNumWant = 50 - defaultMaxScrapeInfoHashes = 50 -) - var ( errNoInfoHash = bittorrent.ClientError("no info hash supplied") errMultipleInfoHashes = bittorrent.ClientError("multiple info hashes supplied") @@ -62,7 +53,7 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ compactStr, _ := qp.String("compact") request.Compact = compactStr != "" && compactStr != "0" - // Parse the infohash from the request. + // Parse the info hash from the request. 
infoHashes := qp.InfoHashes() if len(infoHashes) < 1 { return nil, errNoInfoHash @@ -77,7 +68,7 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ if !ok { return nil, errInvalidPeerID } - request.Peer.ID, err = bittorrent.NewPeerID([]byte(peerID)) + request.ID, err = bittorrent.NewPeerID([]byte(peerID)) if err != nil { return nil, errInvalidPeerID } @@ -100,24 +91,23 @@ func ParseAnnounce(r *http.Request, opts ParseOptions) (*bittorrent.AnnounceRequ } // Determine the number of peers the client wants in the response. - numwant, err := qp.Uint("numwant", 32) + numWant, err := qp.Uint("numwant", 32) if err != nil && !errors.Is(err, bittorrent.ErrKeyNotFound) { return nil, errInvalidParameterNumWant } - // If there were no errors, the user actually provided the numwant. + // If there were no errors, the user actually provided the numWant. request.NumWantProvided = err == nil - request.NumWant = uint32(numwant) + request.NumWant = uint32(numWant) // Parse the port where the client is listening. port, err := qp.Uint("port", 16) if err != nil { return nil, bittorrent.ErrInvalidPort } + request.Port = uint16(port) // Parse the IP address where the client is listening. 
- ip, spoofed := requestedIP(r, qp, opts) - request.Peer.AddrPort = netip.AddrPortFrom(ip, uint16(port)) - request.IPProvided = spoofed + request.RequestAddresses = requestedIPs(r, qp, opts) if err = bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil { request = nil @@ -138,38 +128,43 @@ func ParseScrape(r *http.Request, opts ParseOptions) (*bittorrent.ScrapeRequest, return nil, errNoInfoHash } - ip, _ := requestedIP(r, qp, opts) - request := &bittorrent.ScrapeRequest{ - InfoHashes: infoHashes, - Params: qp, - Addr: ip, + InfoHashes: infoHashes, + Params: qp, + RequestAddresses: requestedIPs(r, qp, opts), } - if err := bittorrent.SanitizeScrape(request, opts.MaxScrapeInfoHashes); err != nil { - return nil, err - } + err = bittorrent.SanitizeScrape(request, opts.MaxScrapeInfoHashes) - return request, nil + return request, err } -// requestedIP determines the IP address for a BitTorrent client request. -func requestedIP(r *http.Request, p bittorrent.Params, opts ParseOptions) (ip netip.Addr, spoofed bool) { +// requestedIPs determines the IP address for a BitTorrent client request. 
+func requestedIPs(r *http.Request, p bittorrent.Params, opts ParseOptions) (addresses bittorrent.RequestAddresses) { if opts.AllowIPSpoofing { for _, f := range []string{"ip", "ipv4", "ipv6"} { if ipStr, ok := p.String(f); ok { - spoofed = true - ip, _ = netip.ParseAddr(ipStr) - return + addresses.Add(parseRequestAddress(ipStr, true)) } } } if ipStr := r.Header.Get(opts.RealIPHeader); ipStr != "" && opts.RealIPHeader != "" { - ip, _ = netip.ParseAddr(ipStr) + addresses.Add(parseRequestAddress(ipStr, false)) + } else { + addrPort, _ := netip.ParseAddrPort(r.RemoteAddr) + addresses.Add(bittorrent.RequestAddress{ + Addr: addrPort.Addr(), + Provided: false, + }) + } + return +} + +func parseRequestAddress(s string, provided bool) (ra bittorrent.RequestAddress) { + a, e := netip.ParseAddr(s) + if e == nil { + ra.Addr, ra.Provided = a, provided } - - addrPort, _ := netip.ParseAddrPort(r.RemoteAddr) - ip = addrPort.Addr() return } diff --git a/frontend/options.go b/frontend/options.go new file mode 100644 index 0000000..093e6a4 --- /dev/null +++ b/frontend/options.go @@ -0,0 +1,53 @@ +package frontend + +import "github.com/sot-tech/mochi/pkg/log" + +// ParseOptions is the configuration used to parse an Announce Request. +// +// If AllowIPSpoofing is true, IPs provided via params will be used. +type ParseOptions struct { + AllowIPSpoofing bool `cfg:"allow_ip_spoofing"` + MaxNumWant uint32 `cfg:"max_numwant"` + DefaultNumWant uint32 `cfg:"default_numwant"` + MaxScrapeInfoHashes uint32 `cfg:"max_scrape_infohashes"` +} + +// Validate sanity checks values set in a config and returns a new config with +// default values replacing anything that is invalid. 
+func (op ParseOptions) Validate() ParseOptions { + valid := op + if op.MaxNumWant <= 0 { + valid.MaxNumWant = defaultMaxNumWant + log.Warn("falling back to default configuration", log.Fields{ + "name": "MaxNumWant", + "provided": op.MaxNumWant, + "default": valid.MaxNumWant, + }) + } + + if op.DefaultNumWant <= 0 { + valid.DefaultNumWant = defaultDefaultNumWant + log.Warn("falling back to default configuration", log.Fields{ + "name": "DefaultNumWant", + "provided": op.DefaultNumWant, + "default": valid.DefaultNumWant, + }) + } + + if op.MaxScrapeInfoHashes <= 0 { + valid.MaxScrapeInfoHashes = defaultMaxScrapeInfoHashes + log.Warn("falling back to default configuration", log.Fields{ + "name": "MaxScrapeInfoHashes", + "provided": op.MaxScrapeInfoHashes, + "default": valid.MaxScrapeInfoHashes, + }) + } + return valid +} + +// Default parser config constants. +const ( + defaultMaxNumWant = 100 + defaultDefaultNumWant = 50 + defaultMaxScrapeInfoHashes = 50 +) diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index e67ecce..4d4451b 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -32,7 +32,7 @@ type Config struct { PrivateKey string `cfg:"private_key"` MaxClockSkew time.Duration `cfg:"max_clock_skew"` EnableRequestTiming bool `cfg:"enable_request_timing"` - ParseOptions + frontend.ParseOptions } // LogFields renders the current config as a set of Logrus fields. 
@@ -67,32 +67,7 @@ func (cfg Config) Validate() Config { log.Warn("UDP private key was not provided, using generated key", log.Fields{"key": validcfg.PrivateKey}) } - if cfg.MaxNumWant <= 0 { - validcfg.MaxNumWant = defaultMaxNumWant - log.Warn("falling back to default configuration", log.Fields{ - "name": "udp.MaxNumWant", - "provided": cfg.MaxNumWant, - "default": validcfg.MaxNumWant, - }) - } - - if cfg.DefaultNumWant <= 0 { - validcfg.DefaultNumWant = defaultDefaultNumWant - log.Warn("falling back to default configuration", log.Fields{ - "name": "udp.DefaultNumWant", - "provided": cfg.DefaultNumWant, - "default": validcfg.DefaultNumWant, - }) - } - - if cfg.MaxScrapeInfoHashes <= 0 { - validcfg.MaxScrapeInfoHashes = defaultMaxScrapeInfoHashes - log.Warn("falling back to default configuration", log.Fields{ - "name": "udp.MaxScrapeInfoHashes", - "provided": cfg.MaxScrapeInfoHashes, - "default": validcfg.MaxScrapeInfoHashes, - }) - } + validcfg.ParseOptions = cfg.ParseOptions.Validate() return validcfg } diff --git a/frontend/udp/parser.go b/frontend/udp/parser.go index a182bfc..bc191fd 100644 --- a/frontend/udp/parser.go +++ b/frontend/udp/parser.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/frontend" ) const ( @@ -49,29 +50,13 @@ var ( errInvalidPeerID = bittorrent.ClientError("invalid info hash") ) -// ParseOptions is the configuration used to parse an Announce Request. -// -// If AllowIPSpoofing is true, IPs provided via params will be used. -type ParseOptions struct { - AllowIPSpoofing bool `cfg:"allow_ip_spoofing"` - MaxNumWant uint32 `cfg:"max_numwant"` - DefaultNumWant uint32 `cfg:"default_numwant"` - MaxScrapeInfoHashes uint32 `cfg:"max_scrape_infohashes"` -} - -// Default parser config constants. -const ( - defaultMaxNumWant = 100 - defaultDefaultNumWant = 50 - defaultMaxScrapeInfoHashes = 50 -) - // ParseAnnounce parses an AnnounceRequest from a UDP request. 
// // If v6Action is true, the announce is parsed the // "old opentracker way": // https://web.archive.org/web/20170503181830/http://opentracker.blog.h3q.com/2007/12/28/the-ipv6-situation/ -func ParseAnnounce(r Request, v6Action bool, opts ParseOptions) (*bittorrent.AnnounceRequest, error) { +func ParseAnnounce(r Request, v6Action bool, opts frontend.ParseOptions) (*bittorrent.AnnounceRequest, error) { + var err error ipEnd := 84 + net.IPv4len if v6Action { ipEnd = 84 + net.IPv6len @@ -81,70 +66,44 @@ func ParseAnnounce(r Request, v6Action bool, opts ParseOptions) (*bittorrent.Ann return nil, errMalformedPacket } + request := new(bittorrent.AnnounceRequest) + // XXX: pure V2 hashes will cause invalid parsing, // but BEP-52 says, that V2 hashes SHOULD be truncated - infoHash := r.Packet[16:36] - peerIDBytes := r.Packet[36:56] - downloaded := binary.BigEndian.Uint64(r.Packet[56:64]) - left := binary.BigEndian.Uint64(r.Packet[64:72]) - uploaded := binary.BigEndian.Uint64(r.Packet[72:80]) + request.InfoHash, err = bittorrent.NewInfoHash(r.Packet[16:36]) + if err != nil { + return nil, errInvalidInfoHash + } + + request.ID, err = bittorrent.NewPeerID(r.Packet[36:56]) + if err != nil { + return nil, errInvalidPeerID + } + + request.Downloaded = binary.BigEndian.Uint64(r.Packet[56:64]) + request.Left = binary.BigEndian.Uint64(r.Packet[64:72]) + request.Uploaded = binary.BigEndian.Uint64(r.Packet[72:80]) eventID := int(r.Packet[83]) if eventID >= len(eventIDs) { return nil, bittorrent.ErrUnknownEvent } + request.Event, request.EventProvided = eventIDs[eventID], true - ip := r.IP - ipProvided := false + request.Add(bittorrent.RequestAddress{Addr: r.IP}) if opts.AllowIPSpoofing { - ipBytes := r.Packet[84:ipEnd] - spoofed, ok := netip.AddrFromSlice(ipBytes) - if !ok { - return nil, bittorrent.ErrInvalidIP + if spoofed, ok := netip.AddrFromSlice(r.Packet[84:ipEnd]); ok { + request.Add(bittorrent.RequestAddress{Addr: spoofed, Provided: true}) } - ipProvided = true - ip = 
spoofed - } - if !opts.AllowIPSpoofing && r.IP.IsUnspecified() { - // We have no IP address to fallback on. - return nil, bittorrent.ErrInvalidIP } - numWant := binary.BigEndian.Uint32(r.Packet[ipEnd+4 : ipEnd+8]) - port := binary.BigEndian.Uint16(r.Packet[ipEnd+8 : ipEnd+10]) - - params, err := handleOptionalParameters(r.Packet[ipEnd+10:]) + request.NumWant, request.NumWantProvided = binary.BigEndian.Uint32(r.Packet[ipEnd+4:ipEnd+8]), true + request.Port = binary.BigEndian.Uint16(r.Packet[ipEnd+8 : ipEnd+10]) + request.Params, err = handleOptionalParameters(r.Packet[ipEnd+10:]) if err != nil { return nil, err } - ih, err := bittorrent.NewInfoHash(infoHash) - if err != nil { - return nil, errInvalidInfoHash - } - - peerID, err := bittorrent.NewPeerID(peerIDBytes) - if err != nil { - return nil, errInvalidPeerID - } - - request := &bittorrent.AnnounceRequest{ - Event: eventIDs[eventID], - InfoHash: ih, - NumWant: numWant, - Left: left, - Downloaded: downloaded, - Uploaded: uploaded, - IPProvided: ipProvided, - NumWantProvided: true, - EventProvided: true, - Peer: bittorrent.Peer{ - ID: peerID, - AddrPort: netip.AddrPortFrom(ip, port), - }, - Params: params, - } - if err = bittorrent.SanitizeAnnounce(request, opts.MaxNumWant, opts.DefaultNumWant); err != nil { request = nil } @@ -214,7 +173,7 @@ func handleOptionalParameters(packet []byte) (bittorrent.Params, error) { } // ParseScrape parses a ScrapeRequest from a UDP request. -func ParseScrape(r Request, opts ParseOptions) (*bittorrent.ScrapeRequest, error) { +func ParseScrape(r Request, opts frontend.ParseOptions) (*bittorrent.ScrapeRequest, error) { // If a scrape isn't at least 36 bytes long, it's malformed. if len(r.Packet) < 36 { return nil, errMalformedPacket @@ -243,7 +202,11 @@ func ParseScrape(r Request, opts ParseOptions) (*bittorrent.ScrapeRequest, error } if err == nil { // Sanitize the request. 
- request = &bittorrent.ScrapeRequest{InfoHashes: infoHashes, Addr: r.IP} + request = &bittorrent.ScrapeRequest{ + InfoHashes: infoHashes, + RequestAddresses: bittorrent.RequestAddresses{bittorrent.RequestAddress{Addr: r.IP}}, + } + err = bittorrent.SanitizeScrape(request, opts.MaxScrapeInfoHashes) } diff --git a/middleware/clientapproval/clientapproval.go b/middleware/clientapproval/clientapproval.go index 607aa7c..644e8ed 100644 --- a/middleware/clientapproval/clientapproval.go +++ b/middleware/clientapproval/clientapproval.go @@ -79,7 +79,7 @@ func build(options conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, erro } func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { - clientID := NewClientID(req.Peer.ID) + clientID := NewClientID(req.ID) if len(h.approved) > 0 { if _, found := h.approved[clientID]; !found { diff --git a/middleware/clientapproval/clientapproval_test.go b/middleware/clientapproval/clientapproval_test.go index af3a462..8dabd2a 100644 --- a/middleware/clientapproval/clientapproval_test.go +++ b/middleware/clientapproval/clientapproval_test.go @@ -63,7 +63,7 @@ func TestHandleAnnounce(t *testing.T) { peerid, err := bittorrent.NewPeerID([]byte(tt.peerID)) require.Nil(t, err) - req.Peer.ID = peerid + req.ID = peerid nctx, err := h.HandleAnnounce(ctx, req, resp) require.Equal(t, ctx, nctx) diff --git a/middleware/hooks.go b/middleware/hooks.go index 523aa2d..4effc59 100644 --- a/middleware/hooks.go +++ b/middleware/hooks.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/storage" ) @@ -63,9 +64,13 @@ func (h *swarmInteractionHook) HandleAnnounce(ctx context.Context, req *bittorre default: storeFn = h.store.PutLeecher } - - if err = storeFn(req.InfoHash, req.Peer); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len { - err = 
storeFn(req.InfoHash.TruncateV1(), req.Peer) + for _, p := range req.Peers() { + if err = storeFn(req.InfoHash, p); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len { + err = storeFn(req.InfoHash.TruncateV1(), p) + } + if err != nil { + break + } } return @@ -84,54 +89,75 @@ type skipResponseHook struct{} // skip. var SkipResponseHookKey = skipResponseHook{} -// type scrapeAddressType struct{} - -// ScrapeIsIPv6Key is the key under which to store whether or not the -// address used to request a scrape was an IPv6 address. -// The value is expected to be of type bool. -// A missing value or a value that is not a bool for this key is equivalent to -// it being set to false. -// var ScrapeIsIPv6Key = scrapeAddressType{} - type responseHook struct { store storage.PeerStorage } +func (h *responseHook) scrape(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { + leechers, seeders, snatched = h.store.ScrapeSwarm(ih) + if len(ih) == bittorrent.InfoHashV2Len { + l, s, n := h.store.ScrapeSwarm(ih.TruncateV1()) + leechers, seeders, snatched = leechers+l, seeders+s, snatched+n + } + return +} + func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (_ context.Context, err error) { if ctx.Value(SkipResponseHookKey) != nil { return ctx, nil } // Add the Scrape data to the response. 
- resp.Incomplete, resp.Complete, _ = h.store.ScrapeSwarm(req.InfoHash) - if len(req.InfoHash) == bittorrent.InfoHashV2Len { - incomplete, complete, _ := h.store.ScrapeSwarm(req.InfoHash.TruncateV1()) - resp.Incomplete, resp.Complete = resp.Incomplete+incomplete, resp.Complete+complete - } + resp.Incomplete, resp.Complete, _ = h.scrape(req.InfoHash) err = h.appendPeers(req, resp) return ctx, err } -func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) error { +type fetchArgs struct { + ih bittorrent.InfoHash + v6 bool +} + +func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (err error) { seeding := req.Left == 0 max := int(req.NumWant) - storePeers, err := h.store.AnnouncePeers(req.InfoHash, seeding, max, req.Peer) - if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) { - return err - } - err = nil - peers := make([]bittorrent.Peer, 0, len(resp.IPv4Peers)+len(resp.IPv6Peers)+len(storePeers)) + peers := make([]bittorrent.Peer, 0, len(resp.IPv4Peers)+len(resp.IPv6Peers)) + primaryIP := req.GetFirst() + v6First := primaryIP.Is6() + args := []fetchArgs{{req.InfoHash, v6First}, {req.InfoHash, !v6First}} - // append peers, which added in middleware - if req.Peer.Addr().Is6() { + if len(req.InfoHash) == bittorrent.InfoHashV2Len { + ih := req.InfoHash.TruncateV1() + args = append(args, fetchArgs{ih, v6First}, fetchArgs{ih, !v6First}) + } + + if v6First { peers = append(peers, resp.IPv6Peers...) peers = append(peers, resp.IPv4Peers...) } else { peers = append(peers, resp.IPv4Peers...) peers = append(peers, resp.IPv6Peers...) } - peers = append(peers, storePeers...) 
+ if l := len(peers); l > max { + peers, max = peers[:max], 0 + } else { + max -= l + } + + for _, a := range args { + if max <= 0 { + break + } + var storePeers []bittorrent.Peer + storePeers, err = h.store.AnnouncePeers(a.ih, seeding, max, a.v6) + if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) { + return err + } + err = nil + peers = append(peers, storePeers...) + max -= len(storePeers) + } // Some clients expect a minimum of their own peer representation returned to // them if they are the only peer in a swarm. @@ -141,31 +167,36 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor } else { resp.Incomplete++ } - peers = append(peers, req.Peer) + peers = append(peers, req.Peers()...) } - uniquePeers := make(map[bittorrent.Peer]interface{}, len(peers)) + l := len(peers) + uniquePeers := make(map[bittorrent.Peer]interface{}, l) - resp.IPv4Peers = make([]bittorrent.Peer, 0, len(peers)/2) - resp.IPv6Peers = make([]bittorrent.Peer, 0, len(peers)/2) + resp.IPv4Peers = make([]bittorrent.Peer, 0, l/2) + resp.IPv6Peers = make([]bittorrent.Peer, 0, l/2) for _, p := range peers { - if err != nil || len(uniquePeers) > max { - break - } if _, found := uniquePeers[p]; !found { - uniquePeers[p] = nil if p.Addr().Is6() { resp.IPv6Peers = append(resp.IPv6Peers, p) + uniquePeers[p] = nil } else if p.Addr().Is4() { resp.IPv4Peers = append(resp.IPv4Peers, p) + uniquePeers[p] = nil } else { - err = bittorrent.ErrInvalidIP + log.Warn("received invalid peer from storage", log.Fields{"peer": p}) } } } + log.Info("responseHook announce peers", log.Fields{ + "infoHash": req.InfoHash, + "requestPeer": req.RequestPeer, + "ipv4Peers": resp.IPv4Peers, + "ipv6Peers": resp.IPv6Peers, + }) - return err + return } func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, resp *bittorrent.ScrapeResponse) (context.Context, error) { @@ -175,12 +206,7 @@ func (h *responseHook) HandleScrape(ctx context.Context, req 
*bittorrent.ScrapeR for _, infoHash := range req.InfoHashes { scr := bittorrent.Scrape{InfoHash: infoHash} - scr.Incomplete, scr.Complete, scr.Snatches = h.store.ScrapeSwarm(infoHash) - if len(infoHash) == bittorrent.InfoHashV2Len { - leechers, seeders, snatched := h.store.ScrapeSwarm(infoHash.TruncateV1()) - scr.Incomplete, scr.Complete, scr.Snatches = scr.Incomplete+leechers, scr.Complete+seeders, scr.Snatches+snatched - } - + scr.Incomplete, scr.Complete, scr.Snatches = h.scrape(infoHash) resp.Files = append(resp.Files, scr) } diff --git a/middleware/logic_test.go b/middleware/logic_test.go index 4650d63..5929354 100644 --- a/middleware/logic_test.go +++ b/middleware/logic_test.go @@ -42,15 +42,15 @@ func (hooks hookList) handleAnnounce(ctx context.Context, req *bittorrent.Announ } func benchHookListV4(b *testing.B, hooks hookList) { - req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{ - AddrPort: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 0), + req := &bittorrent.AnnounceRequest{RequestPeer: bittorrent.RequestPeer{ + RequestAddresses: []bittorrent.RequestAddress{{Addr: netip.MustParseAddr("1.2.3.4")}}, }} benchHookList(b, hooks, req) } func benchHookListV6(b *testing.B, hooks hookList) { - req := &bittorrent.AnnounceRequest{Peer: bittorrent.Peer{ - AddrPort: netip.AddrPortFrom(netip.MustParseAddr("fc00:0001"), 0), + req := &bittorrent.AnnounceRequest{RequestPeer: bittorrent.RequestPeer{ + RequestAddresses: []bittorrent.RequestAddress{{Addr: netip.MustParseAddr("fc00::0001")}}, }} benchHookList(b, hooks, req) } diff --git a/middleware/pkg/random/entropy.go b/middleware/pkg/random/entropy.go index b070a50..6fd01bf 100644 --- a/middleware/pkg/random/entropy.go +++ b/middleware/pkg/random/entropy.go @@ -14,6 +14,6 @@ func DeriveEntropyFromRequest(req *bittorrent.AnnounceRequest) (v0 uint64, v1 ui if len(req.InfoHash) >= bittorrent.InfoHashV1Len { v0 = binary.BigEndian.Uint64([]byte(req.InfoHash[:8])) + 
binary.BigEndian.Uint64([]byte(req.InfoHash[8:16])) } - v1 = binary.BigEndian.Uint64(req.Peer.ID[:8]) + binary.BigEndian.Uint64(req.Peer.ID[8:16]) + v1 = binary.BigEndian.Uint64(req.ID[:8]) + binary.BigEndian.Uint64(req.ID[8:16]) return } diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index f1506c5..07eb4c7 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -152,15 +152,15 @@ func (s store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) (er } // AnnouncePeers is the same function as redis.AnnouncePeers -func (s store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) ([]bittorrent.Peer, error) { +func (s store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { log.Debug("storage: KeyDB: AnnouncePeers", log.Fields{ "infoHash": ih, "seeder": seeder, "numWant": numWant, - "peer": peer, + "v6": v6, }) - return s.GetPeers(ih, seeder, numWant, peer, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { + return s.GetPeers(ih, seeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { return s.SRandMemberN(context.TODO(), infoHashKey, int64(maxCount)) }) } diff --git a/storage/memory/storage.go b/storage/memory/storage.go index d73a7b0..29366b8 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -193,6 +193,7 @@ func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error shard := ps.shards[ps.shardIndex(ih, p.Addr().Is6())] shard.Lock() + defer shard.Unlock() if _, ok := shard.swarms[ih]; !ok { shard.swarms[ih] = swarm{ @@ -209,7 +210,6 @@ func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error // Update the peer in the swarm. 
shard.swarms[ih].seeders[p] = ps.getClock() - shard.Unlock() return nil } @@ -222,14 +222,13 @@ func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) err shard := ps.shards[ps.shardIndex(ih, p.Addr().Is6())] shard.Lock() + defer shard.Unlock() if _, ok := shard.swarms[ih]; !ok { - shard.Unlock() return storage.ErrResourceDoesNotExist } if _, ok := shard.swarms[ih].seeders[p]; !ok { - shard.Unlock() return storage.ErrResourceDoesNotExist } @@ -240,7 +239,6 @@ func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) err delete(shard.swarms, ih) } - shard.Unlock() return nil } @@ -253,6 +251,7 @@ func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error shard := ps.shards[ps.shardIndex(ih, p.Addr().Is6())] shard.Lock() + defer shard.Unlock() if _, ok := shard.swarms[ih]; !ok { shard.swarms[ih] = swarm{ @@ -269,7 +268,6 @@ func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error // Update the peer in the swarm. 
shard.swarms[ih].leechers[p] = ps.getClock() - shard.Unlock() return nil } @@ -282,14 +280,13 @@ func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) er shard := ps.shards[ps.shardIndex(ih, p.Addr().Is6())] shard.Lock() + defer shard.Unlock() if _, ok := shard.swarms[ih]; !ok { - shard.Unlock() return storage.ErrResourceDoesNotExist } if _, ok := shard.swarms[ih].leechers[p]; !ok { - shard.Unlock() return storage.ErrResourceDoesNotExist } @@ -300,7 +297,6 @@ func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) er delete(shard.swarms, ih) } - shard.Unlock() return nil } @@ -313,6 +309,7 @@ func (ps *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) shard := ps.shards[ps.shardIndex(ih, p.Addr().Is6())] shard.Lock() + defer shard.Unlock() if _, ok := shard.swarms[ih]; !ok { shard.swarms[ih] = swarm{ @@ -335,60 +332,44 @@ func (ps *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) // Update the peer in the swarm. shard.swarms[ih].seeders[p] = ps.getClock() - shard.Unlock() return nil } -func parsePeers(peersMap map[bittorrent.Peer]int64, maxCount int, skipPeer bittorrent.Peer) (peers []bittorrent.Peer) { +func parsePeers(peersMap map[bittorrent.Peer]int64, maxCount int) (peers []bittorrent.Peer) { for p := range peersMap { if maxCount == 0 { break } - if p != skipPeer { - peers = append(peers, p) - maxCount-- - } + peers = append(peers, p) + maxCount-- } return } -func (ps *peerStore) getPeers(shard *peerShard, ih bittorrent.InfoHash, maxCount int, leechersOnly bool, skipPeer bittorrent.Peer) (peers []bittorrent.Peer) { +func (ps *peerStore) getPeers(shard *peerShard, ih bittorrent.InfoHash, maxCount int, forSeeder bool) (peers []bittorrent.Peer) { shard.RLock() defer shard.RUnlock() if swarm, ok := shard.swarms[ih]; ok { - if !leechersOnly { - peers = append(peers, parsePeers(swarm.seeders, maxCount, skipPeer)...) 
- maxCount -= len(peers) - } - if maxCount > 0 { - peers = append(peers, parsePeers(swarm.leechers, maxCount, skipPeer)...) + if forSeeder { + peers = parsePeers(swarm.leechers, maxCount) + } else { + peers = append(peers, parsePeers(swarm.seeders, maxCount)...) + if maxCount -= len(peers); maxCount > 0 { + peers = append(peers, parsePeers(swarm.leechers, maxCount)...) + } } } return } -func (ps *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) (peers []bittorrent.Peer, err error) { +func (ps *peerStore) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) { select { case <-ps.closed: panic("attempted to interact with stopped memory store") default: } - isV6 := peer.Addr().Is6() - - if seeder { - // Append leechers as possible. - peers = ps.getPeers(ps.shards[ps.shardIndex(ih, isV6)], ih, numWant, true, peer) - if numWant -= len(peers); numWant > 0 { - peers = append(peers, ps.getPeers(ps.shards[ps.shardIndex(ih, !isV6)], ih, numWant, true, peer)...) - } - } else { - // Append as many seeders as possible. - peers = ps.getPeers(ps.shards[ps.shardIndex(ih, isV6)], ih, numWant, false, peer) - if numWant -= len(peers); numWant > 0 { - peers = append(peers, ps.getPeers(ps.shards[ps.shardIndex(ih, !isV6)], ih, numWant, false, peer)...) 
- } - } + peers = ps.getPeers(ps.shards[ps.shardIndex(ih, v6)], ih, numWant, seeder) return } diff --git a/storage/redis/storage.go b/storage/redis/storage.go index 3e36cdc..1e3a9c3 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -457,17 +457,15 @@ func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) e }) } -func (ps Connection) parsePeersList(peersResult *redis.StringSliceCmd, skipPeerID string) (peers []bittorrent.Peer, err error) { +func (ps Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers []bittorrent.Peer, err error) { var peerIds []string peerIds, err = peersResult.Result() if err = AsNil(err); err == nil { for _, peerID := range peerIds { - if peerID != skipPeerID { - if p, err := bittorrent.NewPeer(peerID); err == nil { - peers = append(peers, p) - } else { - log.Error("storage: Redis: unable to decode leecher", log.Fields{"peerID": peerID}) - } + if p, err := bittorrent.NewPeer(peerID); err == nil { + peers = append(peers, p) + } else { + log.Error("storage: Redis: unable to decode leecher", log.Fields{"peerID": peerID}) } } } @@ -480,29 +478,21 @@ type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd // converts result to bittorrent.Peer array. // If forSeeder set to true - returns only leechers, if false - // seeders and if maxCount not reached - leechers. 
-func (ps Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount int, peer bittorrent.Peer, membersFn getPeersFn) (out []bittorrent.Peer, err error) { - infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6() +func (ps Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) { + infoHash := ih.RawString() - var infoHashKeys []string + infoHashKeys := make([]string, 1, 2) if forSeeder { - infoHashKeys = append(infoHashKeys, InfoHashKey(infoHash, false, isV6), - InfoHashKey(infoHash, false, !isV6)) + infoHashKeys[0] = InfoHashKey(infoHash, false, isV6) } else { - // Append as many peers as possible. - // Priority: - // same ip family seeders > same ip family leechers > - // foreign ip family seeders > foreign ip family leechers - infoHashKeys = append(infoHashKeys, - InfoHashKey(infoHash, true, isV6), - InfoHashKey(infoHash, false, isV6), - InfoHashKey(infoHash, true, !isV6), - InfoHashKey(infoHash, false, !isV6)) + infoHashKeys[0] = InfoHashKey(infoHash, true, isV6) + infoHashKeys = append(infoHashKeys, InfoHashKey(infoHash, false, isV6)) } for _, infoHashKey := range infoHashKeys { var peers []bittorrent.Peer - peers, err = ps.parsePeersList(membersFn(context.TODO(), infoHashKey, maxCount), peerID) + peers, err = ps.parsePeersList(membersFn(context.TODO(), infoHashKey, maxCount)) maxCount -= len(peers) out = append(out, peers...) 
if err != nil || maxCount <= 0 { @@ -525,15 +515,15 @@ func (ps Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount i return } -func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) ([]bittorrent.Peer, error) { +func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { log.Debug("storage: Redis: AnnouncePeers", log.Fields{ "infoHash": ih, "seeder": seeder, "numWant": numWant, - "peer": peer, + "v6": v6, }) - return ps.GetPeers(ih, seeder, numWant, peer, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { + return ps.GetPeers(ih, seeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { return ps.HRandField(ctx, infoHashKey, maxCount, false) }) } diff --git a/storage/storage.go b/storage/storage.go index 50c97d8..37036e9 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -148,33 +148,33 @@ type PeerStorage interface { DataStorage // PutSeeder adds a Seeder to the Swarm identified by the provided // InfoHash. - PutSeeder(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error + PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error // DeleteSeeder removes a Seeder from the Swarm identified by the // provided InfoHash. // // If the Swarm or Peer does not exist, this function returns // ErrResourceDoesNotExist. - DeleteSeeder(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error + DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error // PutLeecher adds a Leecher to the Swarm identified by the provided // InfoHash. // If the Swarm does not exist already, it is created. - PutLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error + PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error // DeleteLeecher removes a Leecher from the Swarm identified by the // provided InfoHash.
// // If the Swarm or Peer does not exist, this function returns // ErrResourceDoesNotExist. - DeleteLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error + DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error // GraduateLeecher promotes a Leecher to a Seeder in the Swarm // identified by the provided InfoHash. // // If the given Peer is not present as a Leecher or the swarm does not exist // already, the Peer is added as a Seeder and no error is returned. - GraduateLeecher(infoHash bittorrent.InfoHash, peer bittorrent.Peer) error + GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error // AnnouncePeers is a best effort attempt to return Peers from the Swarm // identified by the provided InfoHash. @@ -191,7 +191,7 @@ type PeerStorage interface { // leechers // // Returns ErrResourceDoesNotExist if the provided InfoHash is not tracked. - AnnouncePeers(infoHash bittorrent.InfoHash, seeder bool, numWant int, peer bittorrent.Peer) (peers []bittorrent.Peer, err error) + AnnouncePeers(ih bittorrent.InfoHash, seeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) // ScrapeSwarm returns information required to answer a Scrape request // about a Swarm identified by the given InfoHash. @@ -201,7 +201,7 @@ type PeerStorage interface { // filling the Snatches field is optional. // // If the Swarm does not exist, an empty Scrape and no error is returned. - ScrapeSwarm(infoHash bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) + ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) // Stopper is an interface that expects a Stop method to stop the PeerStorage. // For more details see the documentation in the stop package. 
diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index b8020fd..53b57c8 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -404,7 +404,7 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { // AnnounceLeecher can run in parallel. func (bh *benchHolder) AnnounceLeecher(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[0], false, 50, bd.peers[0]) + _, err := ps.AnnouncePeers(bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -415,7 +415,7 @@ func (bh *benchHolder) AnnounceLeecher(b *testing.B) { // AnnounceLeecher1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], false, 50, bd.peers[0]) + _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -426,7 +426,7 @@ func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { // AnnounceSeeder can run in parallel. func (bh *benchHolder) AnnounceSeeder(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[0], true, 50, bd.peers[0]) + _, err := ps.AnnouncePeers(bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -437,7 +437,7 @@ func (bh *benchHolder) AnnounceSeeder(b *testing.B) { // AnnounceSeeder1kInfoHash can run in parallel. 
func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], true, 50, bd.peers[0]) + _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) return err }) } diff --git a/storage/test/storage_test_base.go b/storage/test/storage_test_base.go index 5f1ad59..5d78410 100644 --- a/storage/test/storage_test_base.go +++ b/storage/test/storage_test_base.go @@ -54,11 +54,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) { func (th *testHolder) AnnouncePeers(t *testing.T) { for _, c := range testData { - peer := v4Peer - if c.peer.Addr().Is6() { - peer = v6Peer - } - _, err := th.st.AnnouncePeers(c.ih, false, 50, peer) + _, err := th.st.AnnouncePeers(c.ih, false, 50, c.peer.Addr().Is6()) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -77,19 +73,16 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) { func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { - peer := v4Peer - if c.peer.Addr().Is6() { - peer = v6Peer - } + isV6 := c.peer.Addr().Is6() err := th.st.PutLeecher(c.ih, c.peer) require.Nil(t, err) - peers, err := th.st.AnnouncePeers(c.ih, true, 50, peer) + peers, err := th.st.AnnouncePeers(c.ih, true, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // non-seeder announce should still return the leecher - peers, err = th.st.AnnouncePeers(c.ih, false, 50, peer) + peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) @@ -100,7 +93,7 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { err = th.st.DeleteLeecher(c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(c.ih, true, 50, peer) + peers, err = th.st.AnnouncePeers(c.ih, true, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { 
err = nil } @@ -111,15 +104,12 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { - peer := v4Peer - if c.peer.Addr().Is6() { - peer = v6Peer - } + isV6 := c.peer.Addr().Is6() err := th.st.PutSeeder(c.ih, c.peer) require.Nil(t, err) // Should be leecher to see the seeder - peers, err := th.st.AnnouncePeers(c.ih, false, 50, peer) + peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) @@ -130,7 +120,7 @@ func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { err = th.st.DeleteSeeder(c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(c.ih, false, 50, peer) + peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -141,8 +131,9 @@ func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { + isV6 := c.peer.Addr().Is6() peer := v4Peer - if c.peer.Addr().Is6() { + if isV6 { peer = v6Peer } err := th.st.PutLeecher(c.ih, c.peer) @@ -152,7 +143,7 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { require.Nil(t, err) // Has to be leecher to see the graduated seeder - peers, err := th.st.AnnouncePeers(c.ih, false, 50, peer) + peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) @@ -161,7 +152,7 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { require.Equal(t, storage.ErrResourceDoesNotExist, err) // Verify it's still there - peers, err = th.st.AnnouncePeers(c.ih, false, 50, peer) + peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer))