diff --git a/middleware/torrentapproval/container/directory/directory.go b/middleware/torrentapproval/container/directory/directory.go index 027b0fd..979fa43 100644 --- a/middleware/torrentapproval/container/directory/directory.go +++ b/middleware/torrentapproval/container/directory/directory.go @@ -10,6 +10,7 @@ import ( "crypto/sha256" "fmt" "io" + "iter" "os" "path/filepath" "strings" @@ -53,22 +54,6 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er return nil, fmt.Errorf("unable to deserialise configuration: %w", err) } var err error - d := &directory{ - List: list.List{ - Invert: c.Invert, - Storage: st, - StorageCtx: c.StorageCtx, - }, - closed: make(chan bool), - } - if len(d.StorageCtx) == 0 { - logger.Warn(). - Str("name", "StorageCtx"). - Str("provided", d.StorageCtx). - Str("default", container.DefaultStorageCtxName). - Msg("falling back to default configuration") - d.StorageCtx = container.DefaultStorageCtxName - } if c.Period == 0 { logger.Warn(). Str("name", "Period"). @@ -77,7 +62,12 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er Msg("falling back to default configuration") c.Period = defaultPeriod } - go d.runScan(c.Path, c.Period) + d := NewScanner(list.List{ + Invert: c.Invert, + Storage: st, + StorageCtx: c.StorageCtx, + }, path(c.Path)) + go d.Run(c.Period) return d, err } @@ -99,7 +89,71 @@ type torrentNameInfoStruct struct { Name string `bencode:"name"` } -func (d *directory) runScan(path string, period time.Duration) { +// PathReader - interface for abstract directory reader +type PathReader interface { + // ReadDir returns names of torrent entries. + // Implementation must return absolute names of entries + // to fetch torrent file-like data. + ReadDir() (it iter.Seq[string], err error) + // ReadData returns reader for entry data + ReadData(entry string) (io.ReadCloser, error) +} + +type path string + +var _ PathReader = path("") + +func (p path) ReadDir() (it iter.Seq[string], err error) { + var entries []os.DirEntry + dir := string(p) + if entries, err = os.ReadDir(dir); err == nil { + it = func(yield func(string) bool) { + for _, e := range entries { + if !e.IsDir() && strings.ToLower(filepath.Ext(e.Name())) == ".torrent" { + if !yield(filepath.Join(dir, e.Name())) { + return + } + } + } + } + } + return it, err +} + +func (p path) ReadData(entry string) (io.ReadCloser, error) { + return os.Open(entry) +} + +// NewScanner creates Scanner instance. +func NewScanner(list list.List, reader PathReader) *Scanner { + if len(list.StorageCtx) == 0 { + logger.Warn(). + Str("name", "StorageCtx"). + Str("provided", list.StorageCtx). + Str("default", container.DefaultStorageCtxName). + Msg("falling back to default configuration") + list.StorageCtx = container.DefaultStorageCtxName + } + return &Scanner{ + List: list, + reader: reader, + closed: make(chan bool), + } +} + +// Scanner holds list of approved/rejected torrents +type Scanner struct { + list.List + reader PathReader + closed chan bool +} + +// Run starts periodic directory scanning and blocks until Stop called +func (d *Scanner) Run(period time.Duration) { + if d.reader == nil { + log.Warn().Msg("reader not provided") + return + } t := time.NewTicker(period) defer t.Stop() files := make(map[string][2]bittorrent.InfoHash) @@ -112,16 +166,14 @@ func (d *directory) runScan(path string, period time.Duration) { return case <-t.C: logger.Debug().Msg("starting directory scan") - if entries, err := os.ReadDir(path); err == nil { - for _, e := range entries { - if !e.IsDir() && strings.ToLower(filepath.Ext(e.Name())) == ".torrent" { - tmpFiles[filepath.Join(path, e.Name())] = true - } + if entries, err := d.reader.ReadDir(); err == nil { + for e := range entries { + tmpFiles[e] = true } for p := range tmpFiles { if _, exists := files[p]; !exists { - var f *os.File - if f, err = os.Open(p); err == nil { + var f io.ReadCloser + if f, err = d.reader.ReadData(p); err == nil { var info torrentRawInfoStruct err = bencode.NewDecoder(io.LimitReader(f, maxTorrentSize)).Decode(&info) _ = f.Close() @@ -186,13 +238,8 @@ func (d *directory) runScan(path string, period time.Duration) { } } -type directory struct { - list.List - closed chan bool -} - // Close closes watching of torrent directory -func (d *directory) Close() error { +func (d *Scanner) Close() error { if d.closed != nil { close(d.closed) } diff --git a/middleware/torrentapproval/container/s3/s3.go b/middleware/torrentapproval/container/s3/s3.go index 41c178b..c3a127c 100644 --- a/middleware/torrentapproval/container/s3/s3.go +++ b/middleware/torrentapproval/container/s3/s3.go @@ -1,11 +1,13 @@ +// Package s3 implements container which +// checks if hash present in any of torrent file +// placed in S3-like storage. package s3 import ( "context" - "crypto/sha1" - "crypto/sha256" "fmt" "io" + "iter" "path/filepath" "strings" "time" @@ -13,37 +15,28 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" awss3 "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/sot-tech/mochi/bittorrent" + "github.com/sot-tech/mochi/middleware/torrentapproval/container" + "github.com/sot-tech/mochi/middleware/torrentapproval/container/directory" "github.com/sot-tech/mochi/middleware/torrentapproval/container/list" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/str2bytes" "github.com/sot-tech/mochi/storage" - "github.com/zeebo/bencode" ) var logger = log.NewLogger("middleware/torrent approval/s3") -const ( - defaultPeriod = time.Minute - maxTorrentSize = 10 * 1024 * 1024 -) +const defaultPeriod = time.Minute -// Config - implementation of directory container configuration. +// Config - implementation of S3 container configuration. // Extends list.Config because uses the same storage and Approved function. type Config struct { list.Config Bucket string - Path string + Prefix string Period time.Duration } -type s3 struct { - list.List - closed chan bool -} - func init() { container.Register("s3", build) } @@ -53,23 +46,6 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er if err := conf.Unmarshal(c); err != nil { return nil, fmt.Errorf("unable to deserialise configuration: %w", err) } - var err error - s := &s3{ - List: list.List{ - Invert: c.Invert, - Storage: st, - StorageCtx: c.StorageCtx, - }, - closed: make(chan bool), - } - if len(s.StorageCtx) == 0 { - logger.Warn(). - Str("name", "StorageCtx"). - Str("provided", s.StorageCtx). - Str("default", container.DefaultStorageCtxName). - Msg("falling back to default configuration") - s.StorageCtx = container.DefaultStorageCtxName - } if c.Period == 0 { logger.Warn(). Str("name", "Period"). @@ -79,124 +55,55 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er c.Period = defaultPeriod } - ctx := context.Background() - sdkConfig, err := config.LoadDefaultConfig(ctx) + sdkConfig, err := config.LoadDefaultConfig(context.Background()) if err != nil { - return nil, fmt.Errorf("unable load aws sdk configuration: %w", err) + return nil, fmt.Errorf("unable load AWS S3 SDK configuration: %w", err) } - s3Client := awss3.NewFromConfig(sdkConfig) - go s.runScan(ctx, c.Bucket, c.Path, s3Client, c.Period) + s := directory.NewScanner(list.List{ + Invert: c.Invert, + Storage: st, + StorageCtx: c.StorageCtx, + }, s3{client: awss3.NewFromConfig(sdkConfig), bucket: c.Bucket, prefix: c.Prefix}) + go s.Run(c.Period) + return s, err } -// BencodeRawBytes wrapper for byte slice to get raw 'info' section from -// torrent file -type BencodeRawBytes []byte - -// UnmarshalBencode just appends raw byte slice to result -func (ba *BencodeRawBytes) UnmarshalBencode(in []byte) error { - *ba = append([]byte(nil), in...) - return nil +type s3 struct { + client *awss3.Client + bucket, prefix string } -type torrentRawInfoStruct struct { - Info BencodeRawBytes `bencode:"info"` -} +var _ directory.PathReader = s3{} -type torrentNameInfoStruct struct { - Name string `bencode:"name"` -} - -func (s *s3) runScan(ctx context.Context, bucket, prefix string, s3Client *awss3.Client, period time.Duration) { - t := time.NewTicker(period) - defer t.Stop() - files := make(map[string][2]bittorrent.InfoHash) - tmpFiles := make(map[string]bool) - // nolint:gosec - s1, s2 := sha1.New(), sha256.New() - for { - select { - case <-s.closed: - return - case <-t.C: - logger.Debug().Msg("starting directory scan") - listObj := &awss3.ListObjectsV2Input{Bucket: &bucket, Prefix: &prefix} - if entries, err := s3Client.ListObjectsV2(ctx, listObj); err == nil { - for _, e := range entries.Contents { - if strings.ToLower(filepath.Ext(*e.Key)) == ".torrent" { - tmpFiles[filepath.Join(prefix, *e.Key)] = true +func (s s3) ReadDir() (it iter.Seq[string], err error) { + entries, err := s.client.ListObjectsV2(context.Background(), &awss3.ListObjectsV2Input{ + Bucket: &s.bucket, + Prefix: &s.prefix, + }) + if err == nil { + it = func(yield func(string) bool) { + for _, e := range entries.Contents { + if e.Key != nil && strings.ToLower(filepath.Ext(*e.Key)) == ".torrent" { + if !yield(filepath.Join(s.prefix, *e.Key)) { + return } } - for p := range tmpFiles { - if _, exists := files[p]; !exists { - requestInput := &awss3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(p), - } - - result, err := s3Client.GetObject(ctx, requestInput) - if err != nil { - log.Print(err) - } - var info torrentRawInfoStruct - err = bencode.NewDecoder(io.LimitReader(result.Body, maxTorrentSize)).Decode(&info) - if err == nil { - s1.Write(info.Info) - h1, _ := bittorrent.NewInfoHash(s1.Sum(nil)) - s1.Reset() - - s2.Write(info.Info) - h2, _ := bittorrent.NewInfoHash(s2.Sum(nil)) - s2.Reset() - - files[p] = [2]bittorrent.InfoHash{h1, h2} - var name torrentNameInfoStruct - if err := bencode.DecodeBytes(info.Info, &name); err != nil { - logger.Warn(). - Err(err). - Str("file", p). - Msg("unable to unmarshal torrent info") - } - if len(name.Name) == 0 { - name.Name = list.DUMMY - } - bName := str2bytes.StringToBytes(name.Name) - logger.Err(s.Storage.Put(ctx, s.StorageCtx, storage.Entry{ - Key: h1.RawString(), - Value: bName, - }, storage.Entry{ - Key: h2.RawString(), - Value: bName, - }, storage.Entry{ - Key: h2.TruncateV1().RawString(), - Value: bName, - })). - Str("file", p). - Stringer("infoHash", h1). - Stringer("infoHashV2", h2). - Msg("added torrent to approval list") - } - } - if err != nil { - logger.Warn().Err(err).Str("file", p).Msg("unable to read file") - } - } - for p, ih := range files { - if _, isOk := tmpFiles[p]; !isOk { - delete(files, p) - logger.Err(s.Storage.Delete(ctx, s.StorageCtx, ih[0].RawString(), - ih[1].RawString(), ih[1].TruncateV1().RawString())). - Str("file", p). - Stringer("infoHash", ih[1]). - Stringer("infoHashV2", ih[1]). - Msg("deleted torrent from approval list") - } - } - clear(tmpFiles) - } else { - logger.Warn().Err(err).Msg("unable to get directory content") } } } + return it, err +} + +func (s s3) ReadData(entry string) (data io.ReadCloser, err error) { + var result *awss3.GetObjectOutput + result, err = s.client.GetObject(context.Background(), &awss3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(entry), + }) + if err == nil { + data = result.Body + } + return }