Files
mochi/middleware/clientapproval/clientapproval.go
Lawrence, Rendall 64b27c2df6 (tested) preserve all addresses of peer
* add multiple addresses in request structures and frontend parsers

* move per-ip peer fetch/store from storage to internal hooks

* fetch/store both v1 and v2 info hashes
2022-04-27 00:52:17 +03:00

103 lines
2.5 KiB
Go

// Package clientapproval implements a Hook that fails an Announce based on a
// whitelist or blacklist of BitTorrent client IDs.
package clientapproval
import (
"context"
"errors"
"fmt"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/storage"
)
const (
	// Name is the name by which this middleware is registered with Conf.
	Name = "client approval"
)

// init registers build as this middleware's builder under Name.
func init() {
	middleware.RegisterBuilder(Name, build)
}
var (
	// ErrClientUnapproved is the error returned for an announce whose client
	// ID is missing from the configured (non-empty) whitelist, or is present
	// in the configured blacklist.
	ErrClientUnapproved = bittorrent.ClientError("unapproved client")

	// errBothListsProvided is returned by build when the configuration sets
	// both a whitelist and a blacklist; the two lists are mutually exclusive.
	errBothListsProvided = errors.New("using both whitelist and blacklist is invalid")
)
// Config is the configuration for this middleware: lists of BitTorrent
// client IDs used to approve or reject announcing peers. At most one of the
// two lists may be non-empty, and every entry must be exactly 6 bytes long.
type Config struct {
	// Whitelist, when non-empty, names the only client IDs allowed to announce.
	Whitelist []string
	// Blacklist names client IDs whose announces are rejected.
	Blacklist []string
}
// hook enforces client approval on announces. build never populates both
// sets at once.
type hook struct {
	// approved is the whitelist set: when non-empty, only these IDs may announce.
	// unapproved is the blacklist set: these IDs are always rejected.
	approved, unapproved map[ClientID]struct{}
}
// build constructs the client-approval hook from options. The peer storage
// argument is not used. It returns errBothListsProvided when both a whitelist
// and a blacklist are configured, and an error when any configured client ID
// is not exactly as long as a ClientID.
func build(options conf.MapConfig, _ storage.PeerStorage) (middleware.Hook, error) {
	var cfg Config
	if err := options.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("middleware %s: %w", Name, err)
	}
	// The lists are mutually exclusive; reject before doing any parsing work.
	if len(cfg.Whitelist) > 0 && len(cfg.Blacklist) > 0 {
		return nil, errBothListsProvided
	}
	approved, err := parseClientIDSet(cfg.Whitelist)
	if err != nil {
		return nil, err
	}
	unapproved, err := parseClientIDSet(cfg.Blacklist)
	if err != nil {
		return nil, err
	}
	return &hook{approved: approved, unapproved: unapproved}, nil
}

// parseClientIDSet converts configured client-ID strings into a set of
// ClientID values. Each string must be exactly len(ClientID) bytes long. The
// returned map is always non-nil on success, even for an empty input.
func parseClientIDSet(ids []string) (map[ClientID]struct{}, error) {
	set := make(map[ClientID]struct{}, len(ids))
	for _, s := range ids {
		var cid ClientID
		// len(s) is the byte length, matching the original []byte check
		// without allocating a copy.
		if len(s) != len(cid) {
			return nil, fmt.Errorf("middleware %s: client ID %q must be %d bytes", Name, s, len(cid))
		}
		copy(cid[:], s)
		set[cid] = struct{}{}
	}
	return set, nil
}
// HandleAnnounce rejects the announce with ErrClientUnapproved when the
// client ID derived from the request's peer ID (via NewClientID) is absent
// from a non-empty whitelist, or is present in the blacklist. Otherwise the
// context is passed through unchanged with a nil error.
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
	id := NewClientID(req.ID)
	_, whitelisted := h.approved[id]
	_, blacklisted := h.unapproved[id]
	// An empty whitelist means "no whitelist"; lookups in an empty blacklist
	// simply report not-found.
	if (len(h.approved) > 0 && !whitelisted) || blacklisted {
		return ctx, ErrClientUnapproved
	}
	return ctx, nil
}
// HandleScrape lets every scrape through: client approval is enforced on
// announces only, so ctx is returned unchanged with a nil error.
func (*hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ *bittorrent.ScrapeResponse) (context.Context, error) {
	return ctx, nil
}