diff --git a/bittorrent/params.go b/bittorrent/params.go index 3a08b64..819bfe5 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -1,6 +1,7 @@ package bittorrent import ( + "context" "errors" "net/url" "strconv" @@ -84,6 +85,27 @@ func (rp RouteParams) ByName(name string) string { return "" } +// InjectRouteParamsToContext returns a new context with the specified RouteParams stored under +// the RouteParamsKey key +func InjectRouteParamsToContext(ctx context.Context, rp RouteParams) context.Context { + if rp == nil { + rp = RouteParams{} + } + return context.WithValue(ctx, RouteParamsKey, rp) +} + +// RemapRouteParamsToBgContext returns a new context with a context.Background parent +// and the RouteParams copied from inCtx +func RemapRouteParamsToBgContext(inCtx context.Context) context.Context { + rp, isOk := inCtx.Value(RouteParamsKey).(RouteParams) + if !isOk { + logger.Warn().Msg("unable to fetch route parameters, probably jammed context") + rp = RouteParams{} + } + return context.WithValue(context.Background(), RouteParamsKey, rp) +} + // ParseURLData parses a request URL or UDP URLData as defined in BEP41. // It expects a concatenated string of the request's path and query parts as // defined in RFC 3986. As both the udp: and http: scheme used by BitTorrent diff --git a/cmd/mochi/main.go b/cmd/mochi/main.go index 6b26686..f8564e6 100644 --- a/cmd/mochi/main.go +++ b/cmd/mochi/main.go @@ -37,7 +37,7 @@ func main() { if err := s.Run(*configPath); err != nil { log.Fatal("unable to start server: ", err) } - defer s.Dispose() + defer s.Shutdown() ch := make(chan os.Signal, 2) signal.Notify(ch, os.Interrupt, syscall.SIGTERM) <-ch diff --git a/cmd/mochi/server.go b/cmd/mochi/server.go index d3dc243..94a618c 100644 --- a/cmd/mochi/server.go +++ b/cmd/mochi/server.go @@ -3,20 +3,23 @@ package main import ( "errors" "fmt" + "io" + "sync" + + "github.com/rs/zerolog" "github.com/sot-tech/mochi/frontend" "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) // Server represents the state of a running instance. type Server struct { - storage storage.PeerStorage - logic *middleware.Logic - sg *stop.Group + frontends []io.Closer + hooks []io.Closer + storage storage.PeerStorage } // Run begins an instance of Conf.
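The two context helpers added to bittorrent/params.go above are intended to be used as a pair by the frontends (the HTTP and UDP changes further down do exactly this): inject the route parameters before running the middleware chain, then remap them onto a background-parented context before firing the asynchronous after-hooks, so those hooks are not canceled along with the request. A minimal sketch of that flow, assuming an announce handler; the function name and its error handling are illustrative only, the real call sites live in frontend/http and frontend/udp:

```go
package example

import (
	"context"

	"github.com/sot-tech/mochi/bittorrent"
	"github.com/sot-tech/mochi/middleware"
)

// announce sketches the intended call order: inject RouteParams for the
// hooks, handle the request, then detach from the request context before
// running the background after-hook.
func announce(ctx context.Context, logic *middleware.Logic, req *bittorrent.AnnounceRequest) error {
	ctx = bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{})
	ctx, resp, err := logic.HandleAnnounce(ctx, req)
	if err != nil {
		return err
	}
	// The request-scoped context may be canceled once the response is
	// written, so the after-hook gets a context.Background-based copy.
	ctx = bittorrent.RemapRouteParamsToBgContext(ctx)
	go logic.AfterAnnounce(ctx, req, resp)
	return nil
}
```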
@@ -28,11 +31,9 @@ func (r *Server) Run(configFilePath string) error { return fmt.Errorf("failed to read config: %w", err) } - r.sg = stop.NewGroup() - if len(cfg.MetricsAddr) > 0 { log.Info().Str("address", cfg.MetricsAddr).Msg("starting metrics server") - r.sg.Add(metrics.NewServer(cfg.MetricsAddr)) + r.frontends = append(r.frontends, metrics.NewServer(cfg.MetricsAddr)) } else { log.Info().Msg("metrics disabled because of empty address") } @@ -46,17 +47,30 @@ func (r *Server) Run(configFilePath string) error { if err != nil { return fmt.Errorf("failed to configure pre-hooks: %w", err) } + + for _, h := range preHooks { + if c, isOk := h.(io.Closer); isOk { + r.hooks = append(r.hooks, c) + } + } + postHooks, err := middleware.NewHooks(cfg.PostHooks, r.storage) if err != nil { return fmt.Errorf("failed to configure post-hooks: %w", err) } + for _, h := range postHooks { + if c, isOk := h.(io.Closer); isOk { + r.hooks = append(r.hooks, c) + } + } + if len(cfg.Frontends) > 0 { var fs []frontend.Frontend - r.logic = middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks) - if fs, err = frontend.NewFrontends(cfg.Frontends, r.logic); err == nil { + logic := middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks) + if fs, err = frontend.NewFrontends(cfg.Frontends, logic); err == nil { for _, f := range fs { - r.sg.Add(f) + r.frontends = append(r.frontends, f) } } else { err = fmt.Errorf("failed to configure frontends: %w", err) @@ -68,21 +82,50 @@ func (r *Server) Run(configFilePath string) error { return err } -// Dispose shuts down an instance of Server. -func (r *Server) Dispose() { +// Shutdown shuts down an instance of Server. +func (r *Server) Shutdown() { log.Debug().Msg("stopping frontends and metrics server") - if errs := r.sg.Stop().Wait(); len(errs) > 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down frontends") - } + closeGroup(r.frontends).Msg("frontends stopped") - log.Debug().Msg("stopping logic") - if errs := r.logic.Stop().Wait(); len(errs) > 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down middlewares") - } + log.Debug().Msg("stopping middleware") + closeGroup(r.hooks).Msg("hooks stopped") log.Debug().Msg("stopping peer store") - if errs := r.storage.Stop().Wait(); len(errs) != 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down peer store") + var err error + if r.storage != nil { + err = r.storage.Close() + } else { + err = errors.New("peer store not configured") } + log.Err(err).Msg("peer store stopped") log.Close() } + +func closeGroup(cls []io.Closer) (e *zerolog.Event) { + l := len(cls) + errs := make([]error, l) + wg := sync.WaitGroup{} + wg.Add(l) + for i, cl := range cls { + go func(i int, cl io.Closer) { + defer wg.Done() + if e := cl.Close(); e != nil { + errs[i] = e + } + }(i, cl) + } + wg.Wait() + nnErrs := make([]error, 0, l) + for _, e := range errs { + if e != nil { + nnErrs = append(nnErrs, e) + } + } + var evt *zerolog.Event + if len(nnErrs) > 0 { + evt = log.Error().Errs("errors", nnErrs) + } else { + evt = log.Info() + } + return evt +} diff --git a/dist/example_config.yaml b/dist/example_config.yaml index 3713af9..3e92b1c 100644 --- a/dist/example_config.yaml +++ b/dist/example_config.yaml @@ -14,13 +14,15 @@ min_announce_interval: 15m # /debug/pprof/{cmdline,profile,symbol,trace} serves profiles in the pprof format metrics_addr: "0.0.0.0:6880" +# This block defines named configurations of 
network listeners (frontends). +# At least one listener should be provided. frontends: # This block defines configuration for the tracker's HTTP interface. # If you do not wish to run this, delete this section. - name: http config: # The network interface that will bind to an HTTP server for serving - # BitTorrent traffic. Remove this to disable the non-TLS listener. + # BitTorrent traffic. addr: "0.0.0.0:6969" # Mark this frontend as HTTPS server for serving @@ -32,6 +34,8 @@ frontends: tls_key_path: "" # Enable SO_REUSEPORT to allow starting multiple mochi instances with the same HTTP(S) port. + # You can also use this parameter to define two or more listeners for the same address and port, + # and (possibly) increase throughput. reuse_port: true # The timeout durations for HTTP requests. @@ -97,8 +101,8 @@ frontends: # The maximum number of infohashes that can be scraped in one request. max_scrape_infohashes: 50 - # This block defines configuration for the tracker's UDP interface. - # If you do not wish to run this, delete this section. + # This block defines configuration for the tracker's UDP interface. + # If you do not wish to run this, delete this section. - name: udp config: # The network interface that will bind to a UDP server for serving @@ -106,6 +110,8 @@ frontends: addr: "0.0.0.0:6969" # Enable SO_REUSEPORT to allow starting multiple mochi instances with the same UDP port. + # You can also use this parameter to define two or more listeners for the same address and port, + # and (a little) increase throughput (faster queue processing because of multiple 'workers'). reuse_port: true # The leeway for a timestamp on a connection ID. diff --git a/frontend/frontend.go b/frontend/frontend.go index e871d96..b00c1e1 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -4,12 +4,12 @@ package frontend import ( "fmt" + "io" "sync" "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) var ( @@ -46,7 +46,7 @@ func RegisterBuilder(name string, b Builder) { // Frontend dummy interface for bittorrent frontends type Frontend interface { - stop.Stopper + io.Closer } // NewFrontends is a utility function for initializing Frontend-s in bulk.
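With stop.Stopper gone, shutdown is expressed purely through io.Closer, and for hooks it is optional: cmd/mochi/server.go above keeps only the hooks that actually implement it, and the torrentapproval hook below forwards Close to its container the same way. The collection step, sketched on its own under that assumption; the helper name is illustrative:

```go
package example

import (
	"io"

	"github.com/sot-tech/mochi/middleware"
)

// collectClosers keeps only the hooks that own resources and therefore
// implement io.Closer; hooks without resources need no shutdown step.
func collectClosers(hooks []middleware.Hook) []io.Closer {
	closers := make([]io.Closer, 0, len(hooks))
	for _, h := range hooks {
		if c, ok := h.(io.Closer); ok {
			closers = append(closers, c)
		}
	}
	return closers
}
```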
diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index fe93992..4e39661 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -18,7 +18,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" ) var ( @@ -81,17 +80,17 @@ type httpFE struct { } // NewFrontend builds and starts http bittorrent frontend from provided configuration -func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend, err error) { +func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, error) { var cfg Config + var err error if err = c.Unmarshal(&cfg); err != nil { - return + return nil, err } if cfg, err = cfg.Validate(); err != nil { - return + return nil, err } if len(cfg.AnnounceRoutes) < 1 || len(cfg.ScrapeRoutes) < 1 { - err = errRoutesNotProvided - return + return nil, errRoutesNotProvided } f := &httpFE{ @@ -110,7 +109,7 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend if cfg.UseTLS { var cert tls.Certificate if cert, err = tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath); err != nil { - return + return nil, err } f.srv.TLSConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -150,34 +149,17 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend return f, nil } -// Stop provides a thread-safe way to shut down a currently running Frontend. -func (f *httpFE) Stop() stop.Result { - c := make(stop.Channel) - if f.srv != nil { - go func() { - c.Done(f.srv.Shutdown(context.Background())) - }() - } - return c.Result() +// Close provides a thread-safe way to shut down a currently running Frontend. +func (f *httpFE) Close() error { + return f.srv.Shutdown(context.Background()) } -func injectRouteParamsToContext(ctx context.Context, ps httprouter.Params) context.Context { - rp := bittorrent.RouteParams{} - for _, p := range ps { - rp = append(rp, bittorrent.RouteParam{Key: p.Key, Value: p.Value}) +func httpParamsToRouteParams(in httprouter.Params) (out bittorrent.RouteParams) { + out = make([]bittorrent.RouteParam, 0, len(in)) + for _, p := range in { + out = append(out, bittorrent.RouteParam{Key: p.Key, Value: p.Value}) } - return context.WithValue(ctx, bittorrent.RouteParamsKey, rp) -} - -func remapRouteParamsToBgContext(inCtx context.Context) context.Context { - rp, isOk := inCtx.Value(bittorrent.RouteParamsKey).(bittorrent.RouteParams) - if !isOk { - rp = bittorrent.RouteParams{} - } else { - logger.Warn().Msg("unable to fetch route parameters, probably jammed context") - } - // FIXME: cancelable context - return context.WithValue(context.TODO(), bittorrent.RouteParamsKey, rp) + return } // announceRoute parses and responds to an Announce. 
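The UDP frontend and the memory/redis stores further down guard their Close methods with a sync.Once, so a repeated Close neither panics on an already-closed channel nor waits twice. A hedged sketch of that common shape, with illustrative type and field names:

```go
package example

import "sync"

// worker owns background goroutines that exit once closed is closed.
type worker struct {
	closed chan struct{}
	wg     sync.WaitGroup
	once   sync.Once
}

// Close is safe to call any number of times: the shutdown sequence runs
// exactly once and later calls return immediately.
func (w *worker) Close() error {
	w.once.Do(func() {
		close(w.closed) // signal the background goroutine(s)
		w.wg.Wait()     // wait until they have drained
	})
	return nil
}
```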
@@ -200,7 +182,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro } addr = req.GetFirst() - ctx := injectRouteParamsToContext(r.Context(), ps) + ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps)) ctx, resp, err := f.logic.HandleAnnounce(ctx, req) if err != nil { WriteError(w, err) @@ -215,8 +197,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro } // next actions are background and should not be canceled after http writer closed - ctx = remapRouteParamsToBgContext(ctx) - + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go f.logic.AfterAnnounce(ctx, req, resp) } @@ -239,7 +220,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout } addr = req.GetFirst() - ctx := injectRouteParamsToContext(r.Context(), ps) + ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps)) ctx, resp, err := f.logic.HandleScrape(ctx, req) if err != nil { WriteError(w, err) @@ -254,8 +235,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout } // next actions are background and should not be canceled after http writer closed - ctx = remapRouteParamsToBgContext(ctx) - + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go f.logic.AfterScrape(ctx, req, resp) } diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index 6dadd61..97c960e 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -20,7 +20,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" ) @@ -81,6 +80,8 @@ type udpFE struct { logic *middleware.Logic maxClockSkew time.Duration collectTimings bool + ctxCancel context.CancelFunc + onceCloser sync.Once frontend.ParseOptions } @@ -109,43 +110,37 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, } if f.socket, err = cfg.ListenUDP(); err == nil { + var ctx context.Context + ctx, f.ctxCancel = context.WithCancel(context.Background()) f.wg.Add(1) - go func() { - if err := f.serve(); err != nil { + go func(ctx context.Context) { + if err := f.serve(ctx); err != nil { logger.Fatal().Err(err).Msg("server failed") } - }() + }(ctx) } return f, err } -// Stop provides a thread-safe way to shut down a currently running Frontend. -func (t *udpFE) Stop() stop.Result { - select { - case <-t.closing: - return stop.AlreadyStopped - default: - } - - c := make(stop.Channel) - go func() { +// Close provides a thread-safe way to shut down a currently running Frontend. +func (t *udpFE) Close() (err error) { + t.onceCloser.Do(func() { close(t.closing) - var err error if t.socket != nil { + t.ctxCancel() _ = t.socket.SetReadDeadline(time.Now()) t.wg.Wait() err = t.socket.Close() } - c.Done(err) - }() + }) - return c.Result() + return } // serve blocks while listening and serving UDP BitTorrent requests // until Stop() is called or an error is returned. 
-func (t *udpFE) serve() error { +func (t *udpFE) serve(ctx context.Context) error { pool := bytepool.NewBytePool(2048) defer t.wg.Done() @@ -188,7 +183,7 @@ func (t *udpFE) serve() error { if t.collectTimings && metrics.Enabled() { start = time.Now() } - action, err := t.handleRequest( + action, err := t.handleRequest(ctx, Request{(*buffer)[:n], addr}, ResponseWriter{t.socket, addrPort}, ) @@ -218,7 +213,7 @@ func (w ResponseWriter) Write(b []byte) (int, error) { } // handleRequest parses and responds to a UDP Request. -func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, err error) { +func (t *udpFE) handleRequest(ctx context.Context, r Request, w ResponseWriter) (actionName string, err error) { if len(r.Packet) < 16 { // Malformed, no client packets are less than 16 bytes. // We explicitly return nothing in case this is a DoS attempt. @@ -265,9 +260,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e return } - var ctx context.Context var resp *bittorrent.AnnounceResponse - ctx, resp, err = t.logic.HandleAnnounce(context.Background(), req) + ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{}) + ctx, resp, err = t.logic.HandleAnnounce(ctx, req) if err != nil { WriteError(w, txID, err) return @@ -275,6 +270,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, r.IP.Is6()) + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go t.logic.AfterAnnounce(ctx, req, resp) case scrapeActionID: @@ -287,9 +283,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e return } - var ctx context.Context var resp *bittorrent.ScrapeResponse - ctx, resp, err = t.logic.HandleScrape(context.Background(), req) + ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{}) + ctx, resp, err = t.logic.HandleScrape(ctx, req) if err != nil { WriteError(w, txID, err) return @@ -297,6 +293,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e WriteScrape(w, txID, resp) + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go t.logic.AfterScrape(ctx, req, resp) default: diff --git a/frontend/udp/frontend_test.go b/frontend/udp/frontend_test.go index 8ae8b8c..0313e9b 100644 --- a/frontend/udp/frontend_test.go +++ b/frontend/udp/frontend_test.go @@ -29,8 +29,7 @@ func TestStartStopRaceIssue437(t *testing.T) { if err != nil { t.Fatal(err) } - errC := fe.Stop() - if errs := <-errC; len(errs) != 0 { - t.Fatal(errs) + if err = fe.Close(); err != nil { + t.Fatal(err) } } diff --git a/frontend/udp/writer.go b/frontend/udp/writer.go index c6d8e34..5ed0efe 100644 --- a/frontend/udp/writer.go +++ b/frontend/udp/writer.go @@ -3,7 +3,6 @@ package udp import ( "encoding/binary" "errors" - "fmt" "io" "time" @@ -12,17 +11,15 @@ import ( // WriteError writes the failure reason as a null-terminated string. func WriteError(w io.Writer, txID []byte, err error) { - // If the client wasn't at fault, acknowledge it. - var clientErr bittorrent.ClientError - if !errors.Is(err, &clientErr) { - err = fmt.Errorf("internal error occurred: %w", err) - } - buf := reqRespBufferPool.Get() defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, errorActionID) - _, _ = buf.WriteString(err.Error()) - _, _ = buf.WriteRune('\000') + // If the client wasn't at fault, acknowledge it. 
+ if !errors.As(err, new(bittorrent.ClientError)) { + buf.WriteString("internal error occurred: ") + } + buf.WriteString(err.Error()) + buf.WriteByte('\000') _, _ = w.Write(buf.Bytes()) } @@ -50,7 +47,7 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse, } for _, peer := range peers { - _, _ = buf.Write(peer.Addr().AsSlice()) + buf.Write(peer.Addr().AsSlice()) _ = binary.Write(buf, binary.BigEndian, peer.Port()) } @@ -79,8 +76,7 @@ func WriteConnectionID(w io.Writer, txID, connID []byte) { defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, connectActionID) - _, _ = buf.Write(connID) - + buf.Write(connID) _, _ = w.Write(buf.Bytes()) } diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 34556ae..65c6eaf 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -19,7 +19,6 @@ import ( "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -110,13 +109,12 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err return } -func (h *hook) Stop() stop.Result { +func (h *hook) Close() error { logger.Debug().Msg("attempting to shutdown JWT middleware") - c := make(stop.Channel) if h.jwks != nil { - go h.jwks.EndBackground() + h.jwks.EndBackground() } - return c.Result() + return nil } type verifiableClaims interface { diff --git a/middleware/logic.go b/middleware/logic.go index 458c156..26495e1 100644 --- a/middleware/logic.go +++ b/middleware/logic.go @@ -6,7 +6,6 @@ import ( "time" "github.com/sot-tech/mochi/bittorrent" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -119,25 +118,3 @@ func (l *Logic) Ping(ctx context.Context) (err error) { } return } - -// Stop stops the Logic. -// -// This stops any hooks that implement stop.Stopper. -func (l *Logic) Stop() stop.Result { - stopGroup := stop.NewGroup() - for _, hook := range l.preHooks { - stoppable, ok := hook.(stop.Stopper) - if ok { - stopGroup.Add(stoppable) - } - } - - for _, hook := range l.postHooks { - stoppable, ok := hook.(stop.Stopper) - if ok { - stopGroup.Add(stoppable) - } - } - - return stopGroup.Stop() -} diff --git a/middleware/torrentapproval/container/directory/directory.go b/middleware/torrentapproval/container/directory/directory.go index 9358cbc..16714c8 100644 --- a/middleware/torrentapproval/container/directory/directory.go +++ b/middleware/torrentapproval/container/directory/directory.go @@ -17,7 +17,6 @@ import ( "github.com/sot-tech/mochi/middleware/torrentapproval/container/list" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -49,11 +48,6 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er }, watcher: nil, } - var w *dirwatch.Instance - if w, err = dirwatch.New(c.Path); err != nil { - return nil, fmt.Errorf("unable to initialize directory watch: %w", err) - } - d.watcher = w if len(d.StorageCtx) == 0 { logger.Warn(). Str("name", "StorageCtx"). 
@@ -62,6 +56,11 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er Msg("falling back to default configuration") d.StorageCtx = container.DefaultStorageCtxName } + var w *dirwatch.Instance + if w, err = dirwatch.New(c.Path); err != nil { + return nil, fmt.Errorf("unable to initialize directory watch: %w", err) + } + d.watcher = w go func() { for event := range d.watcher.Events { var mi *metainfo.MetaInfo @@ -124,10 +123,10 @@ type directory struct { watcher *dirwatch.Instance } -// Stop closes watching of torrent directory -func (d *directory) Stop() stop.Result { - st := make(stop.Channel) - d.watcher.Close() - st.Done() - return st.Result() +// Close closes watching of torrent directory +func (d *directory) Close() error { + if d.watcher != nil { + d.watcher.Close() + } + return nil } diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index 93ca348..977f8b1 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -5,6 +5,7 @@ package torrentapproval import ( "context" "fmt" + "io" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/middleware" @@ -17,7 +18,6 @@ import ( // import static list to enable appropriate support _ "github.com/sot-tech/mochi/middleware/torrentapproval/container/list" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -86,9 +86,9 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ return ctx, nil } -func (h *hook) Stop() stop.Result { - if st, isOk := h.hashContainer.(stop.Stopper); isOk { - return st.Stop() +func (h *hook) Close() (err error) { + if cl, isOk := h.hashContainer.(io.Closer); isOk { + err = cl.Close() } - return stop.AlreadyStopped + return err } diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index 03a40bc..7801532 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -14,7 +14,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) const ( @@ -50,14 +49,9 @@ func AddressFamily(ip netip.Addr) string { } } -// Stop shuts down the server. -func (s *Server) Stop() stop.Result { - c := make(stop.Channel) - go func() { - c.Done(s.srv.Shutdown(context.Background())) - }() - - return c.Result() +// Close shuts down the server. +func (s *Server) Close() error { + return s.srv.Shutdown(context.Background()) } // NewServer creates a new instance of a Prometheus server that asynchronously diff --git a/pkg/stop/stop.go b/pkg/stop/stop.go deleted file mode 100644 index d3b29dd..0000000 --- a/pkg/stop/stop.go +++ /dev/null @@ -1,125 +0,0 @@ -// Package stop implements a pattern for shutting down a group of processes. -package stop - -import ( - "sync" -) - -// Channel is used to return zero or more errors asynchronously. Call Done() -// once to pass errors to the Channel. -type Channel chan []error - -// Result is a receive-only version of Channel. Call Wait() once to receive any -// returned errors. -type Result <-chan []error - -// Done adds zero or more errors to the Channel and closes it, indicating the -// caller has finished stopping. It should be called exactly once. -func (ch Channel) Done(errs ...error) { - if len(errs) > 0 && errs[0] != nil { - ch <- errs - } - close(ch) -} - -// Result converts a Channel to a Result. 
-func (ch Channel) Result() <-chan []error { - return ch -} - -// Wait blocks until Done() is called on the underlying Channel and returns any -// errors. It should be called exactly once. -func (r Result) Wait() []error { - return <-r -} - -// AlreadyStopped is a closed error channel to be used by Funcs when -// an element was already stopped. -var AlreadyStopped Result - -// AlreadyStoppedFunc is a Func that returns AlreadyStopped. -var AlreadyStoppedFunc = func() Result { return AlreadyStopped } - -func init() { - closeMe := make(Channel) - close(closeMe) - AlreadyStopped = closeMe.Result() -} - -// Stopper is an interface that allows a clean shutdown. -type Stopper interface { - // Stop returns a channel that indicates whether the stop was - // successful. - // - // The channel can either return one error or be closed. - // Closing the channel signals a clean shutdown. - // Stop() should return immediately and perform the actual shutdown in a - // separate goroutine. - Stop() Result -} - -// Func is a function that can be used to provide a clean shutdown. -type Func func() Result - -// Group is a collection of Stoppers that can be stopped all at once. -type Group struct { - stoppables []Func - sync.Mutex -} - -// NewGroup allocates a new Group. -func NewGroup() *Group { - return &Group{ - stoppables: make([]Func, 0), - } -} - -// Add appends a Stopper to the Group. -func (cg *Group) Add(toAdd Stopper) { - cg.Lock() - defer cg.Unlock() - - cg.stoppables = append(cg.stoppables, toAdd.Stop) -} - -// AddFunc appends a Func to the Group. -func (cg *Group) AddFunc(toAddFunc Func) { - cg.Lock() - defer cg.Unlock() - - cg.stoppables = append(cg.stoppables, toAddFunc) -} - -// Stop stops all members of the Group. -// -// Stopping will be done in a concurrent fashion. -// The slice of errors returned contains all errors returned by stopping the -// members. -func (cg *Group) Stop() Result { - cg.Lock() - defer cg.Unlock() - - whenDone := make(Channel) - - waitChannels := make([]Result, 0, len(cg.stoppables)) - for _, toStop := range cg.stoppables { - waitFor := toStop() - if waitFor == nil { - panic("received a nil chan from Stop") - } - waitChannels = append(waitChannels, waitFor) - } - - go func() { - var errors []error - for _, waitForMe := range waitChannels { - childErrors := waitForMe.Wait() - if len(childErrors) > 0 { - errors = append(errors, childErrors...) - } - } - whenDone.Done(errors...) 
- }() - - return whenDone.Result() -} diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index 7bdca86..59d65e5 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -12,14 +12,12 @@ package keydb import ( "context" "errors" - "time" "github.com/go-redis/redis/v8" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" r "github.com/sot-tech/mochi/storage/redis" ) @@ -79,10 +77,7 @@ func newStore(cfg r.Config) (*store, error) { var st *store if err == nil { - st = &store{ - Connection: rs, - peerTTL: uint(cfg.PeerLifetime.Seconds()), - } + st = &store{Connection: rs, peerTTL: uint(cfg.PeerLifetime.Seconds())} } return st, err @@ -177,24 +172,3 @@ func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32 Msg("scrape swarm") return s.ScrapeIH(ctx, ih, s.SCard) } - -func (*store) GCAware() bool { - return false -} - -func (*store) ScheduleGC(_, _ time.Duration) {} - -func (*store) StatisticsAware() bool { - return false -} - -func (*store) ScheduleStatisticsCollection(_ time.Duration) {} - -func (s *store) Stop() stop.Result { - c := make(stop.Channel) - if s.UniversalClient != nil { - c.Done(s.UniversalClient.Close()) - s.UniversalClient = nil - } - return c.Result() -} diff --git a/storage/memory/storage.go b/storage/memory/storage.go index a88c6ad..d62b37c 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -14,7 +14,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -96,8 +95,9 @@ type peerStore struct { cfg Config shards []*peerShard - closed chan struct{} - wg sync.WaitGroup + closed chan struct{} + wg sync.WaitGroup + onceCloser sync.Once } var _ storage.PeerStorage = &peerStore{} @@ -468,17 +468,9 @@ func (ds *dataStore) Delete(_ context.Context, ctx string, keys ...string) error return nil } -func (*dataStore) Preservable() bool { - return false -} +func (*dataStore) Preservable() bool { return false } -func (*peerStore) GCAware() bool { - return true -} - -func (*peerStore) StatisticsAware() bool { - return true -} +func (ds *dataStore) Close() error { return nil } // GC deletes all Peers from the PeerStorage which are older than the // cutoff time. @@ -542,12 +534,9 @@ func (*peerStore) Ping(context.Context) error { return nil } -func (ps *peerStore) Stop() stop.Result { - c := make(stop.Channel) - go func() { - if ps.closed != nil { - close(ps.closed) - } +func (ps *peerStore) Close() error { + ps.onceCloser.Do(func() { + close(ps.closed) ps.wg.Wait() // Explicitly deallocate our storage. 
@@ -556,9 +545,7 @@ func (ps *peerStore) Stop() stop.Result { shards[i] = &peerShard{swarms: make(map[bittorrent.InfoHash]swarm)} } ps.shards = shards + }) - c.Done() - }() - - return c.Result() + return nil } diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 0f238d0..32b69c6 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -20,7 +20,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -76,7 +75,13 @@ func newStore(cfg Config) (storage.PeerStorage, error) { return nil, err } - return &store{Config: cfg, Pool: con, wg: sync.WaitGroup{}, closed: make(chan any)}, nil + return &store{ + Config: cfg, + Pool: con, + wg: sync.WaitGroup{}, + closed: make(chan any), + onceCloser: sync.Once{}, + }, nil } type peerQueryConf struct { @@ -215,8 +220,9 @@ func (cfg Config) Validate() (Config, error) { type store struct { Config *pgxpool.Pool - wg sync.WaitGroup - closed chan any + wg sync.WaitGroup + closed chan any + onceCloser sync.Once } func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) { @@ -242,7 +248,7 @@ func (s *store) Put(ctx context.Context, storeCtx string, values ...storage.Entr default: var batch pgx.Batch for _, v := range values { - batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(v.Key), pValue: v.Value}) + batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(v.Key), pValue: v.Value}) } err = s.txBatch(ctx, &batch) } @@ -281,11 +287,10 @@ func (s *store) Preservable() bool { return true } -func (s *store) GCAware() bool { - return len(s.GCQuery) > 0 -} - func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) { + if len(s.GCQuery) == 0 { + return + } s.wg.Add(1) go func() { defer s.wg.Done() @@ -311,11 +316,10 @@ func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) { }() } -func (s *store) StatisticsAware() bool { - return len(s.InfoHashCountQuery) > 0 -} - func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) { + if len(s.InfoHashCountQuery) == 0 { + return + } s.wg.Add(1) go func() { defer s.wg.Done() @@ -357,9 +361,7 @@ func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor pPort: peer.Port(), pSeeder: seeder, pV6: peer.Addr().Is6(), - } - if s.GCAware() { - args[pCreated] = timecache.Now() + pCreated: timecache.Now(), } _, err = s.Exec(ctx, s.Peer.AddQuery, args) return @@ -576,19 +578,12 @@ func (s *store) Ping(ctx context.Context) error { return err } -func (s *store) Stop() stop.Result { - c := make(stop.Channel) +func (s *store) Close() error { go func() { - if s.closed != nil { - close(s.closed) - } + close(s.closed) s.wg.Wait() - if s.Pool != nil { - logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.") - s.Close() - s.Pool = nil - } - c.Done() + logger.Info().Msg("pg exiting. 
mochi does not clear data in database when exiting.") + s.Pool.Close() }() - return c.Result() + return nil } diff --git a/storage/redis/storage.go b/storage/redis/storage.go index d50b4b9..daf19bc 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -36,7 +36,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -100,11 +99,7 @@ func newStore(cfg Config) (*store, error) { return nil, err } - return &store{ - Connection: rs, - closed: make(chan any), - wg: sync.WaitGroup{}, - }, nil + return &store{Connection: rs, closed: make(chan any)}, nil } // Config holds the configuration of a redis PeerStorage. @@ -287,8 +282,9 @@ type Connection struct { type store struct { Connection - closed chan any - wg sync.WaitGroup + closed chan any + wg sync.WaitGroup + onceCloser sync.Once } func (ps *store) count(key string, getLength bool) (n uint64) { @@ -611,14 +607,6 @@ func (*Connection) Preservable() bool { return true } -func (*store) GCAware() bool { - return true -} - -func (*store) StatisticsAware() bool { - return true -} - // Ping sends `PING` request to Redis server func (ps *Connection) Ping(ctx context.Context) error { return ps.UniversalClient.Ping(ctx).Err() @@ -769,21 +757,12 @@ func (ps *store) gc(cutoff time.Time) { } } -func (ps *store) Stop() stop.Result { - c := make(stop.Channel) - go func() { - if ps.closed != nil { - close(ps.closed) - } +func (ps *store) Close() (err error) { + ps.onceCloser.Do(func() { + close(ps.closed) ps.wg.Wait() - var err error - if ps.UniversalClient != nil { - logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey) - err = ps.UniversalClient.Close() - ps.UniversalClient = nil - } - c.Done(err) - }() - - return c.Result() + logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey) + err = ps.UniversalClient.Close() + }) + return } diff --git a/storage/storage.go b/storage/storage.go index bd61386..5b3aa73 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,13 +5,13 @@ package storage import ( "context" "fmt" + "io" "sync" "time" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) const ( @@ -93,6 +93,7 @@ var ErrResourceDoesNotExist = bittorrent.ClientError("resource does not exist") // DataStorage is the interface, used for implementing store for arbitrary data type DataStorage interface { + io.Closer // Put used to place arbitrary k-v data with specified context // into storage. storeCtx parameter used to group data // (i.e. data only for specific middleware module: hash key, table name etc...) @@ -189,27 +190,23 @@ type PeerStorage interface { // Ping used for checks if storage is alive // (connection could be established, enough space etc.) Ping(ctx context.Context) error +} - // GCAware marks that this storage supports periodic - // peers collection - GCAware() bool - +// GarbageCollector marks that this storage supports periodic +// stale peers collection +type GarbageCollector interface { // ScheduleGC used to delete stale data, such as timed out seeders/leechers. 
// Note: implementation must create subroutine by itself ScheduleGC(gcInterval, peerLifeTime time.Duration) +} - // StatisticsAware marks that this storage supports periodic - // statistics collection - StatisticsAware() bool - +// StatisticsCollector marks that this storage supports periodic +// statistics collection +type StatisticsCollector interface { // ScheduleStatisticsCollection used to receive statistics information about hashes, // seeders and leechers count. // Note: implementation must create subroutine by itself ScheduleStatisticsCollection(reportInterval time.Duration) - - // Stopper is an interface that expects a Stop method to stop the PeerStorage. - // For more details see the documentation in the stop package. - stop.Stopper } // RegisterDriver makes a Driver available by the provided name. @@ -256,27 +253,27 @@ func NewStorage(cfg conf.NamedMapConfig) (ps PeerStorage, err error) { return } - if gc := ps.GCAware(); gc { + if gc, isOk := ps.(GarbageCollector); isOk { gcInterval, peerTTL := c.sanitizeGCConfig() logger.Info(). Str("name", cfg.Name). Dur("gcInterval", gcInterval). Dur("peerTTL", peerTTL). Msg("scheduling GC") - ps.ScheduleGC(gcInterval, peerTTL) + gc.ScheduleGC(gcInterval, peerTTL) } else { logger.Debug(). Str("name", cfg.Name). Msg("storage does not support GC") } - if st := ps.StatisticsAware(); st { + if st, isOk := ps.(StatisticsCollector); isOk { if statInterval := c.sanitizeStatisticsConfig(); statInterval > 0 { logger.Info(). Str("name", cfg.Name). Dur("statInterval", statInterval). Msg("scheduling statistics collection") - ps.ScheduleStatisticsCollection(statInterval) + st.ScheduleStatisticsCollection(statInterval) } else { logger.Info().Str("name", cfg.Name).Msg("statistics collection disabled because of zero reporting interval") } diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index a8023d0..fa2765a 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -3,6 +3,7 @@ package test import ( + "context" "math/rand" "net" "net/netip" @@ -100,9 +101,8 @@ func (bh *benchHolder) runBenchmark(b *testing.B, parallel bool, sf benchSetupFu } b.StopTimer() - errChan := ps.Stop() - for err := range errChan { - b.Fatal(err) + if err := ps.Close(); err != nil { + b.Fatal() } } @@ -125,7 +125,7 @@ func (bh *benchHolder) Nop(b *testing.B) { // Put can run in parallel. func (bh *benchHolder) Put(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -135,7 +135,7 @@ func (bh *benchHolder) Put(b *testing.B) { // Put1k can run in parallel. func (bh *benchHolder) Put1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -145,7 +145,7 @@ func (bh *benchHolder) Put1k(b *testing.B) { // Put1kInfoHash can run in parallel. func (bh *benchHolder) Put1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -155,7 +155,7 @@ func (bh *benchHolder) Put1kInfoHash(b *testing.B) { // Put1kInfoHash1k can run in parallel. 
func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -166,11 +166,11 @@ func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { // PutDelete can not run in parallel. func (bh *benchHolder) PutDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -180,11 +180,11 @@ func (bh *benchHolder) PutDelete(b *testing.B) { // PutDelete1k can not run in parallel. func (bh *benchHolder) PutDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -194,11 +194,11 @@ func (bh *benchHolder) PutDelete1k(b *testing.B) { // PutDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -208,11 +208,11 @@ func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { // PutDelete1kInfoHash1k can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -223,7 +223,7 @@ func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { // DeleteNonexist can run in parallel. func (bh *benchHolder) DeleteNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -234,7 +234,7 @@ func (bh *benchHolder) DeleteNonexist(b *testing.B) { // DeleteNonexist can run in parallel. 
func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -245,7 +245,7 @@ func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { // DeleteNonexist1kInfoHash can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -256,7 +256,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { // DeleteNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -267,7 +267,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { // GradNonexist can run in parallel. func (bh *benchHolder) GradNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -278,7 +278,7 @@ func (bh *benchHolder) GradNonexist(b *testing.B) { // GradNonexist1k can run in parallel. func (bh *benchHolder) GradNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -289,7 +289,7 @@ func (bh *benchHolder) GradNonexist1k(b *testing.B) { // GradNonexist1kInfoHash can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -301,7 +301,7 @@ func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { // GradNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -313,15 +313,15 @@ func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { // PutGradDelete can not run in parallel. 
func (bh *benchHolder) PutGradDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[0]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -330,15 +330,15 @@ func (bh *benchHolder) PutGradDelete(b *testing.B) { // PutGradDelete1k can not run in parallel. func (bh *benchHolder) PutGradDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -348,15 +348,15 @@ func (bh *benchHolder) PutGradDelete1k(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -366,15 +366,15 @@ func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. 
func (bh *benchHolder) PutGradDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -385,9 +385,9 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { for i, peer := range bd.peers { var err error if i < l/2 { - err = ps.PutLeecher(nil, ih, peer) + err = ps.PutLeecher(context.TODO(), ih, peer) } else { - err = ps.PutSeeder(nil, ih, peer) + err = ps.PutSeeder(context.TODO(), ih, peer) } if err != nil { return err @@ -404,7 +404,7 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { // AnnounceLeecher can run in parallel. func (bh *benchHolder) AnnounceLeecher(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -415,7 +415,7 @@ func (bh *benchHolder) AnnounceLeecher(b *testing.B) { // AnnounceLeecher1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -426,7 +426,7 @@ func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { // AnnounceSeeder can run in parallel. func (bh *benchHolder) AnnounceSeeder(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -437,7 +437,7 @@ func (bh *benchHolder) AnnounceSeeder(b *testing.B) { // AnnounceSeeder1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -448,7 +448,7 @@ func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { // ScrapeSwarm can run in parallel. func (bh *benchHolder) ScrapeSwarm(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(nil, bd.infoHashes[0]) + ps.ScrapeSwarm(context.TODO(), bd.infoHashes[0]) return nil }) } @@ -458,7 +458,7 @@ func (bh *benchHolder) ScrapeSwarm(b *testing.B) { // ScrapeSwarm1kInfoHash can run in parallel. 
func (bh *benchHolder) ScrapeSwarm1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(nil, bd.infoHashes[i%ihCount]) + ps.ScrapeSwarm(context.TODO(), bd.infoHashes[i%ihCount]) return nil }) } diff --git a/storage/test/storage_test_base.go b/storage/test/storage_test_base.go index a14b3f4..a360965 100644 --- a/storage/test/storage_test_base.go +++ b/storage/test/storage_test_base.go @@ -3,6 +3,7 @@ package test import ( + "context" "errors" "testing" @@ -40,7 +41,7 @@ type hashPeer struct { func (th *testHolder) DeleteSeeder(t *testing.T) { for _, c := range testData { - err := th.st.DeleteSeeder(nil, c.ih, c.peer) + err := th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -54,14 +55,14 @@ func (th *testHolder) PutLeecher(t *testing.T) { if c.peer.Addr().Is6() { peer = v6Peer } - err := th.st.PutLeecher(nil, c.ih, peer) + err := th.st.PutLeecher(context.TODO(), c.ih, peer) require.Nil(t, err) } } func (th *testHolder) DeleteLeecher(t *testing.T) { for _, c := range testData { - err := th.st.DeleteLeecher(nil, c.ih, c.peer) + err := th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -71,7 +72,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) { func (th *testHolder) AnnouncePeers(t *testing.T) { for _, c := range testData { - _, err := th.st.AnnouncePeers(nil, c.ih, false, 50, c.peer.Addr().Is6()) + _, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, c.peer.Addr().Is6()) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -81,7 +82,7 @@ func (th *testHolder) AnnouncePeers(t *testing.T) { func (th *testHolder) ScrapeSwarm(t *testing.T) { for _, c := range testData { - l, s, n := th.st.ScrapeSwarm(nil, c.ih) + l, s, n := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(0), s) require.Equal(t, uint32(0), l) require.Equal(t, uint32(0), n) @@ -91,26 +92,26 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) { func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutLeecher(nil, c.ih, c.peer) + err := th.st.PutLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err := th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // non-seeder announce should still return the leecher - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(nil, c.ih) + l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(2), l) require.Equal(t, uint32(0), s) - err = th.st.DeleteLeecher(nil, c.ih, c.peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -122,22 +123,22 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutSeeder(nil, c.ih, c.peer) + 
err := th.st.PutSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) // Should be leecher to see the seeder - peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(nil, c.ih) + l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(1), l) require.Equal(t, uint32(1), s) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -153,44 +154,44 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { if isV6 { peer = v6Peer } - err := th.st.PutLeecher(nil, c.ih, c.peer) + err := th.st.PutLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - err = th.st.GraduateLeecher(nil, c.ih, c.peer) + err = th.st.GraduateLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) // Has to be leecher to see the graduated seeder - peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Deleting the Peer as a Leecher should have no effect - err = th.st.DeleteLeecher(nil, c.ih, c.peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) // Verify it's still there - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Clean up - err = th.st.DeleteLeecher(nil, c.ih, peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, peer) require.Nil(t, err) // Test ErrDNE for missing leecher - err = th.st.DeleteLeecher(nil, c.ih, peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -200,35 +201,35 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { for _, c := range testData { - err := th.st.Put(nil, kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) + err := th.st.Put(context.TODO(), kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) require.Nil(t, err) // check if exist in ctx we put - contains, err := th.st.Contains(nil, kvStoreCtx, c.peer.String()) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) require.True(t, contains) // check if not exist in another ctx - contains, err = th.st.Contains(nil, "", c.peer.String()) + contains, err = th.st.Contains(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.False(t, contains) // check value and type in ctx we put - out, err := th.st.Load(nil, kvStoreCtx, c.peer.String()) + out, err := 
th.st.Load(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, c.ih, ih) // check value is nil in another ctx - dummy, err := th.st.Load(nil, "", c.peer.String()) + dummy, err := th.st.Load(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.Nil(t, dummy) - err = th.st.Delete(nil, kvStoreCtx, c.peer.String()) + err = th.st.Delete(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) - contains, err = th.st.Contains(nil, "", c.peer.String()) + contains, err = th.st.Contains(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.False(t, contains) } @@ -245,29 +246,29 @@ func (th *testHolder) CustomBulkPutContainsLoadDelete(t *testing.T) { Value: []byte(c.ih.RawString()), }) } - err := th.st.Put(nil, kvStoreCtx, pairs...) + err := th.st.Put(context.TODO(), kvStoreCtx, pairs...) require.Nil(t, err) // check if exist in ctx we put for _, k := range keys { - contains, err := th.st.Contains(nil, kvStoreCtx, k) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k) require.Nil(t, err) require.True(t, contains) } // check value and type in ctx we put for _, p := range pairs { - out, _ := th.st.Load(nil, kvStoreCtx, p.Key) + out, _ := th.st.Load(context.TODO(), kvStoreCtx, p.Key) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, p.Value, []byte(ih.RawString())) } - err = th.st.Delete(nil, kvStoreCtx, keys...) + err = th.st.Delete(context.TODO(), kvStoreCtx, keys...) require.Nil(t, err) for _, k := range keys { - contains, err := th.st.Contains(nil, kvStoreCtx, k) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k) require.Nil(t, err) require.False(t, contains) } @@ -304,8 +305,8 @@ func RunTests(t *testing.T, p storage.PeerStorage) { t.Run("CustomPutContainsLoadDelete", th.CustomPutContainsLoadDelete) t.Run("CustomBulkPutContainsLoadDelete", th.CustomBulkPutContainsLoadDelete) - e := th.st.Stop() - require.Nil(t, <-e) + e := th.st.Close() + require.Nil(t, e) } func containsPeer(peers []bittorrent.Peer, p bittorrent.Peer) bool {
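Tying the storage changes together: the boolean GCAware/StatisticsAware methods are replaced by the optional GarbageCollector and StatisticsCollector interfaces, and NewStorage (storage/storage.go above) schedules the background jobs only for stores that implement them. A hedged sketch of how a store opts in, using an illustrative store type and a prune helper that is not part of the real code:

```go
package example

import (
	"time"

	"github.com/sot-tech/mochi/storage"
)

// gcStore stands in for a PeerStorage implementation; by adding ScheduleGC
// it would also satisfy storage.GarbageCollector and opt into pruning.
type gcStore struct {
	// ...peer data, closed channel, wait group...
}

// ScheduleGC starts the store's own pruning loop; per the interface docs,
// the implementation is responsible for spawning the goroutine itself.
func (s *gcStore) ScheduleGC(gcInterval, peerLifeTime time.Duration) {
	go func() {
		t := time.NewTicker(gcInterval)
		defer t.Stop()
		for range t.C {
			s.prune(time.Now().Add(-peerLifeTime)) // drop peers older than the cutoff
		}
	}()
}

func (s *gcStore) prune(cutoff time.Time) { /* illustrative no-op */ }

// scheduleIfSupported mirrors what NewStorage does: capability detection
// is now a plain type assertion instead of a GCAware() flag.
func scheduleIfSupported(ps storage.PeerStorage, gcInterval, ttl time.Duration) {
	if gc, ok := ps.(storage.GarbageCollector); ok {
		gc.ScheduleGC(gcInterval, ttl)
	}
}
```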