diff --git a/cmd/chihaya/e2e.go b/cmd/chihaya/e2e.go index b1530f0..b527e1c 100644 --- a/cmd/chihaya/e2e.go +++ b/cmd/chihaya/e2e.go @@ -65,7 +65,7 @@ func generateInfohash() [20]byte { panic(fmt.Errorf("not enough randomness? Got %d bytes", n)) } - return [20]byte(bittorrent.InfoHashFromBytes(b)) + return bittorrent.InfoHashFromBytes(b) } func test(addr string, delay time.Duration) error { diff --git a/middleware/torrentapproval/container/container.go b/middleware/torrentapproval/container/container.go index 9f6648d..28f718b 100644 --- a/middleware/torrentapproval/container/container.go +++ b/middleware/torrentapproval/container/container.go @@ -3,25 +3,21 @@ package container import ( "errors" "github.com/chihaya/chihaya/bittorrent" - "github.com/chihaya/chihaya/pkg/stop" - "gopkg.in/yaml.v2" "sync" ) -type Constructor func () Configuration - -type Configuration interface { - Build() (Container, error) +type Builder interface { + Build([]byte) (Container, error) } var ( - constructorsMU sync.Mutex - constructors = make(map[string]Constructor) + buildersMU sync.Mutex + builders = make(map[string]Builder) ErrContainerDoesNotExist = errors.New("torrent hash container with that name does not exist") ) -func Register(n string, c Constructor) { +func Register(n string, c Builder) { if len(n) == 0 { panic("middleware: could not register a Container with an empty name") } @@ -29,28 +25,25 @@ func Register(n string, c Constructor) { panic("middleware: could not register a Container with nil builder constructor") } - constructorsMU.Lock() - defer constructorsMU.Unlock() - constructors[n] = c + buildersMU.Lock() + defer buildersMU.Unlock() + builders[n] = c } type Container interface { - stop.Stopper Contains(bittorrent.InfoHash) bool } func GetContainer(name string, confBytes []byte) (Container, error) { - constructorsMU.Lock() - defer constructorsMU.Unlock() + + buildersMU.Lock() + defer buildersMU.Unlock() var err error var cn Container - if getConfig, exist := constructors[name]; !exist { + if builder, exist := builders[name]; !exist { err = ErrContainerDoesNotExist } else { - conf := getConfig() - if err = yaml.Unmarshal(confBytes, &conf); err == nil { - cn, err = conf.Build() - } + cn, err = builder.Build(confBytes) } return cn, err } diff --git a/middleware/torrentapproval/container/directory/directory.go b/middleware/torrentapproval/container/directory/directory.go index e95de54..1635d1a 100644 --- a/middleware/torrentapproval/container/directory/directory.go +++ b/middleware/torrentapproval/container/directory/directory.go @@ -6,35 +6,40 @@ import ( "github.com/chihaya/chihaya/middleware/torrentapproval/container" "github.com/chihaya/chihaya/middleware/torrentapproval/container/list" "github.com/chihaya/chihaya/pkg/stop" + "gopkg.in/yaml.v2" "sync" ) func init() { - container.Register("list", func() container.Configuration { - return Config{} - }) + container.Register("directory", builder{}) } +type builder struct {} + type Config struct { WhitelistPath string `yaml:"whitelist_path"` BlacklistPath string `yaml:"blacklist_path"` } -func (b Config) Build() (container.Container, error) { - if len(b.WhitelistPath) > 0 && len(b.BlacklistPath) > 0 { +func (b builder) Build(confBytes []byte) (container.Container, error) { + c := new(Config) + if err := yaml.Unmarshal(confBytes, c); err != nil { + return nil, fmt.Errorf("unable to deserialise configuration: %v", err) + } + if len(c.WhitelistPath) > 0 && len(c.BlacklistPath) > 0 { return nil, fmt.Errorf("using both whitelist and blacklist is invalid") } var err error lst := &directory{ List: list.List{ Hashes: sync.Map{}, - Invert: len(b.WhitelistPath) == 0, + Invert: len(c.WhitelistPath) == 0, }, watcher: nil, } - dir := b.WhitelistPath + dir := c.WhitelistPath if lst.Invert { - dir = b.BlacklistPath + dir = c.BlacklistPath } var w *dirwatch.Instance w, err = dirwatch.New(dir) diff --git a/middleware/torrentapproval/container/list/list.go b/middleware/torrentapproval/container/list/list.go index fd948c2..77a89c9 100644 --- a/middleware/torrentapproval/container/list/list.go +++ b/middleware/torrentapproval/container/list/list.go @@ -5,16 +5,16 @@ import ( "fmt" "github.com/chihaya/chihaya/bittorrent" "github.com/chihaya/chihaya/middleware/torrentapproval/container" - "github.com/chihaya/chihaya/pkg/stop" + "gopkg.in/yaml.v2" "sync" ) func init() { - container.Register("list", func() container.Configuration { - return Config{} - }) + container.Register("list", builder{}) } +type builder struct {} + type Config struct { Whitelist []string `yaml:"whitelist"` Blacklist []string `yaml:"blacklist"` @@ -22,7 +22,11 @@ type Config struct { var DUMMY struct{} -func (c Config) Build() (container.Container, error) { +func (b builder) Build(confBytes []byte) (container.Container, error) { + c := new(Config) + if err := yaml.Unmarshal(confBytes, c); err != nil { + return nil, fmt.Errorf("unable to deserialise configuration: %v", err) + } if len(c.Whitelist) > 0 && len(c.Blacklist) > 0 { return nil, fmt.Errorf("using both whitelist and blacklist is invalid") } @@ -55,10 +59,6 @@ type List struct { Hashes sync.Map } -func (l *List) Stop() stop.Result { - return stop.AlreadyStopped -} - func (l *List) Contains(hash bittorrent.InfoHash) bool { _, result := l.Hashes.Load(hash) return result != l.Invert diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index 7030dd9..45e34b5 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -11,6 +11,8 @@ import ( "github.com/chihaya/chihaya/bittorrent" "github.com/chihaya/chihaya/middleware" + _ "github.com/chihaya/chihaya/middleware/torrentapproval/container/directory" + _ "github.com/chihaya/chihaya/middleware/torrentapproval/container/list" ) // Name is the name by which this middleware is registered with Chihaya. @@ -73,5 +75,8 @@ func (h *hook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeRequest, } func (h *hook) Stop() stop.Result { - return h.hashContainer.Stop() + if st, isOk := h.hashContainer.(stop.Stopper); isOk{ + return st.Stop() + } + return stop.AlreadyStopped } diff --git a/middleware/torrentapproval/torrentapproval_test.go b/middleware/torrentapproval/torrentapproval_test.go index 16da4b3..6781408 100644 --- a/middleware/torrentapproval/torrentapproval_test.go +++ b/middleware/torrentapproval/torrentapproval_test.go @@ -68,7 +68,7 @@ func TestHandleAnnounce(t *testing.T) { for _, tt := range cases { t.Run(fmt.Sprintf("testing hash %s", tt.ih), func(t *testing.T) { d := driver{} - cfg, err := yaml.Marshal(tt) + cfg, err := yaml.Marshal(tt.cfg) require.Nil(t, err) h, err := d.NewHook(cfg) require.Nil(t, err)