Merge pull request #21 from sot-tech/ctx_fallthrough

Used for passing context to middleware and storage
to allow request cancellation (i.e. if http client closed 
connection before response received).

Reworked (simplified) server shutdown.
This commit is contained in:
SOT-TECH
2022-11-08 17:37:54 +03:00
committed by GitHub
25 changed files with 502 additions and 666 deletions
+21
View File
@@ -1,6 +1,7 @@
package bittorrent
import (
"context"
"errors"
"net/url"
"strconv"
@@ -84,6 +85,26 @@ func (rp RouteParams) ByName(name string) string {
return ""
}
// InjectRouteParamsToContext returns new context with specified RouteParams placed in
// RouteParamsKey key
func InjectRouteParamsToContext(ctx context.Context, rp RouteParams) context.Context {
if rp == nil {
rp = RouteParams{}
}
return context.WithValue(ctx, RouteParamsKey, rp)
}
// RemapRouteParamsToBgContext returns new context with context.Background parent
// and copied RouteParams from inCtx
func RemapRouteParamsToBgContext(inCtx context.Context) context.Context {
rp, isOk := inCtx.Value(RouteParamsKey).(RouteParams)
if !isOk {
logger.Warn().Msg("unable to fetch route parameters, probably jammed context")
rp = RouteParams{}
}
return context.WithValue(context.Background(), RouteParamsKey, rp)
}
// ParseURLData parses a request URL or UDP URLData as defined in BEP41.
// It expects a concatenated string of the request's path and query parts as
// defined in RFC 3986. As both the udp: and http: scheme used by BitTorrent
+1 -1
View File
@@ -37,7 +37,7 @@ func main() {
if err := s.Run(*configPath); err != nil {
log.Fatal("unable to start server: ", err)
}
defer s.Dispose()
defer s.Shutdown()
ch := make(chan os.Signal, 2)
signal.Notify(ch, os.Interrupt, syscall.SIGTERM)
<-ch
+64 -21
View File
@@ -3,20 +3,23 @@ package main
import (
"errors"
"fmt"
"io"
"sync"
"github.com/rs/zerolog"
"github.com/sot-tech/mochi/frontend"
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
)
// Server represents the state of a running instance.
type Server struct {
storage storage.PeerStorage
logic *middleware.Logic
sg *stop.Group
frontends []io.Closer
hooks []io.Closer
storage storage.PeerStorage
}
// Run begins an instance of Conf.
@@ -28,11 +31,9 @@ func (r *Server) Run(configFilePath string) error {
return fmt.Errorf("failed to read config: %w", err)
}
r.sg = stop.NewGroup()
if len(cfg.MetricsAddr) > 0 {
log.Info().Str("address", cfg.MetricsAddr).Msg("starting metrics server")
r.sg.Add(metrics.NewServer(cfg.MetricsAddr))
r.frontends = append(r.frontends, metrics.NewServer(cfg.MetricsAddr))
} else {
log.Info().Msg("metrics disabled because of empty address")
}
@@ -46,17 +47,30 @@ func (r *Server) Run(configFilePath string) error {
if err != nil {
return fmt.Errorf("failed to configure pre-hooks: %w", err)
}
for _, h := range preHooks {
if c, isOk := h.(io.Closer); isOk {
r.hooks = append(r.hooks, c)
}
}
postHooks, err := middleware.NewHooks(cfg.PostHooks, r.storage)
if err != nil {
return fmt.Errorf("failed to configure post-hooks: %w", err)
}
for _, h := range postHooks {
if c, isOk := h.(io.Closer); isOk {
r.hooks = append(r.hooks, c)
}
}
if len(cfg.Frontends) > 0 {
var fs []frontend.Frontend
r.logic = middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks)
if fs, err = frontend.NewFrontends(cfg.Frontends, r.logic); err == nil {
logic := middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks)
if fs, err = frontend.NewFrontends(cfg.Frontends, logic); err == nil {
for _, f := range fs {
r.sg.Add(f)
r.frontends = append(r.frontends, f)
}
} else {
err = fmt.Errorf("failed to configure frontends: %w", err)
@@ -68,21 +82,50 @@ func (r *Server) Run(configFilePath string) error {
return err
}
// Dispose shuts down an instance of Server.
func (r *Server) Dispose() {
// Shutdown shuts down an instance of Server.
func (r *Server) Shutdown() {
log.Debug().Msg("stopping frontends and metrics server")
if errs := r.sg.Stop().Wait(); len(errs) > 0 {
log.Error().Errs("errors", errs).Msg("error occurred while shutting down frontends")
}
closeGroup(r.frontends).Msg("frontends stopped")
log.Debug().Msg("stopping logic")
if errs := r.logic.Stop().Wait(); len(errs) > 0 {
log.Error().Errs("errors", errs).Msg("error occurred while shutting down middlewares")
}
log.Debug().Msg("stopping middleware")
closeGroup(r.hooks).Msg("hooks stopped")
log.Debug().Msg("stopping peer store")
if errs := r.storage.Stop().Wait(); len(errs) != 0 {
log.Error().Errs("errors", errs).Msg("error occurred while shutting down peer store")
var err error
if r.storage != nil {
err = r.storage.Close()
} else {
err = errors.New("peer store not configured")
}
log.Err(err).Msg("peer store stopped")
log.Close()
}
func closeGroup(cls []io.Closer) (e *zerolog.Event) {
l := len(cls)
errs := make([]error, l)
wg := sync.WaitGroup{}
wg.Add(l)
for i, cl := range cls {
go func(i int, cl io.Closer) {
defer wg.Done()
if e := cl.Close(); e != nil {
errs[i] = e
}
}(i, cl)
}
wg.Wait()
nnErrs := make([]error, 0, l)
for _, e := range errs {
if e != nil {
nnErrs = append(nnErrs, e)
}
}
var evt *zerolog.Event
if len(nnErrs) > 0 {
evt = log.Error().Errs("errors", nnErrs)
} else {
evt = log.Info()
}
return evt
}
+9 -3
View File
@@ -14,13 +14,15 @@ min_announce_interval: 15m
# /debug/pprof/{cmdline,profile,symbol,trace} serves profiles in the pprof format
metrics_addr: "0.0.0.0:6880"
# This block defines named configurations of network listeners (frontends).
# At least one listener should be provided.
frontends:
# This block defines configuration for the tracker's HTTP interface.
# If you do not wish to run this, delete this section.
- name: http
config:
# The network interface that will bind to an HTTP server for serving
# BitTorrent traffic. Remove this to disable the non-TLS listener.
# BitTorrent traffic.
addr: "0.0.0.0:6969"
# Mark this frontend as HTTPS server for serving
@@ -32,6 +34,8 @@ frontends:
tls_key_path: ""
# Enable SO_REUSEPORT to allow starting multiple mochi instances with the same HTTP(S) port.
# You can also use this parameter to define two or mote listeners for the same address and port,
# and (possibly) increase throughput.
reuse_port: true
# The timeout durations for HTTP requests.
@@ -97,8 +101,8 @@ frontends:
# The maximum number of infohashes that can be scraped in one request.
max_scrape_infohashes: 50
# This block defines configuration for the tracker's UDP interface.
# If you do not wish to run this, delete this section.
# This block defines configuration for the tracker's UDP interface.
# If you do not wish to run this, delete this section.
- name: udp
config:
# The network interface that will bind to a UDP server for serving
@@ -106,6 +110,8 @@ frontends:
addr: "0.0.0.0:6969"
# Enable SO_REUSEPORT to allow starting multiple mochi instances with the same UDP port.
# You can also use this parameter to define two or mote listeners for the same address and port,
# and (a little) increase throughput (faster queue processing because of multiple 'workers').
reuse_port: true
# The leeway for a timestamp on a connection ID.
+2 -2
View File
@@ -4,12 +4,12 @@ package frontend
import (
"fmt"
"io"
"sync"
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
)
var (
@@ -46,7 +46,7 @@ func RegisterBuilder(name string, b Builder) {
// Frontend dummy interface for bittorrent frontends
type Frontend interface {
stop.Stopper
io.Closer
}
// NewFrontends is a utility function for initializing Frontend-s in bulk.
+31 -28
View File
@@ -18,7 +18,6 @@ import (
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
)
var (
@@ -81,17 +80,17 @@ type httpFE struct {
}
// NewFrontend builds and starts http bittorrent frontend from provided configuration
func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend, err error) {
func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, error) {
var cfg Config
var err error
if err = c.Unmarshal(&cfg); err != nil {
return
return nil, err
}
if cfg, err = cfg.Validate(); err != nil {
return
return nil, err
}
if len(cfg.AnnounceRoutes) < 1 || len(cfg.ScrapeRoutes) < 1 {
err = errRoutesNotProvided
return
return nil, errRoutesNotProvided
}
f := &httpFE{
@@ -110,7 +109,7 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend
if cfg.UseTLS {
var cert tls.Certificate
if cert, err = tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath); err != nil {
return
return nil, err
}
f.srv.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
@@ -150,23 +149,17 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend
return f, nil
}
// Stop provides a thread-safe way to shut down a currently running Frontend.
func (f *httpFE) Stop() stop.Result {
c := make(stop.Channel)
if f.srv != nil {
go func() {
c.Done(f.srv.Shutdown(context.Background()))
}()
}
return c.Result()
// Close provides a thread-safe way to shut down a currently running Frontend.
func (f *httpFE) Close() error {
return f.srv.Shutdown(context.Background())
}
func injectRouteParamsToContext(ctx context.Context, ps httprouter.Params) context.Context {
rp := bittorrent.RouteParams{}
for _, p := range ps {
rp = append(rp, bittorrent.RouteParam{Key: p.Key, Value: p.Value})
func httpParamsToRouteParams(in httprouter.Params) (out bittorrent.RouteParams) {
out = make([]bittorrent.RouteParam, 0, len(in))
for _, p := range in {
out = append(out, bittorrent.RouteParam{Key: p.Key, Value: p.Value})
}
return context.WithValue(ctx, bittorrent.RouteParamsKey, rp)
return
}
// announceRoute parses and responds to an Announce.
@@ -189,7 +182,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro
}
addr = req.GetFirst()
ctx := injectRouteParamsToContext(context.Background(), ps)
ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps))
ctx, resp, err := f.logic.HandleAnnounce(ctx, req)
if err != nil {
WriteError(w, err)
@@ -203,6 +196,8 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro
return
}
// next actions are background and should not be canceled after http writer closed
ctx = bittorrent.RemapRouteParamsToBgContext(ctx)
go f.logic.AfterAnnounce(ctx, req, resp)
}
@@ -225,7 +220,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout
}
addr = req.GetFirst()
ctx := injectRouteParamsToContext(context.Background(), ps)
ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps))
ctx, resp, err := f.logic.HandleScrape(ctx, req)
if err != nil {
WriteError(w, err)
@@ -239,18 +234,26 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout
return
}
// next actions are background and should not be canceled after http writer closed
ctx = bittorrent.RemapRouteParamsToBgContext(ctx)
go f.logic.AfterScrape(ctx, req, resp)
}
func (f *httpFE) ping(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var err error
status := http.StatusOK
ctx := r.Context()
if r.Method == http.MethodGet {
err = f.logic.Ping(context.TODO())
err = f.logic.Ping(ctx)
}
if err == nil {
w.WriteHeader(http.StatusOK)
} else {
if err != nil {
logger.Error().Err(err).Msg("ping completed with error")
w.WriteHeader(http.StatusServiceUnavailable)
status = http.StatusServiceUnavailable
}
if ctxErr := ctx.Err(); ctxErr == nil {
w.WriteHeader(status)
} else {
logger.Info().Err(ctxErr).Str("ip", r.RemoteAddr).Msg("ping request cancelled")
}
}
+22 -25
View File
@@ -20,7 +20,6 @@ import (
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/pkg/timecache"
)
@@ -81,6 +80,8 @@ type udpFE struct {
logic *middleware.Logic
maxClockSkew time.Duration
collectTimings bool
ctxCancel context.CancelFunc
onceCloser sync.Once
frontend.ParseOptions
}
@@ -109,43 +110,37 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend,
}
if f.socket, err = cfg.ListenUDP(); err == nil {
var ctx context.Context
ctx, f.ctxCancel = context.WithCancel(context.Background())
f.wg.Add(1)
go func() {
if err := f.serve(); err != nil {
go func(ctx context.Context) {
if err := f.serve(ctx); err != nil {
logger.Fatal().Err(err).Msg("server failed")
}
}()
}(ctx)
}
return f, err
}
// Stop provides a thread-safe way to shut down a currently running Frontend.
func (t *udpFE) Stop() stop.Result {
select {
case <-t.closing:
return stop.AlreadyStopped
default:
}
c := make(stop.Channel)
go func() {
// Close provides a thread-safe way to shut down a currently running Frontend.
func (t *udpFE) Close() (err error) {
t.onceCloser.Do(func() {
close(t.closing)
var err error
if t.socket != nil {
t.ctxCancel()
_ = t.socket.SetReadDeadline(time.Now())
t.wg.Wait()
err = t.socket.Close()
}
c.Done(err)
}()
})
return c.Result()
return
}
// serve blocks while listening and serving UDP BitTorrent requests
// until Stop() is called or an error is returned.
func (t *udpFE) serve() error {
func (t *udpFE) serve(ctx context.Context) error {
pool := bytepool.NewBytePool(2048)
defer t.wg.Done()
@@ -188,7 +183,7 @@ func (t *udpFE) serve() error {
if t.collectTimings && metrics.Enabled() {
start = time.Now()
}
action, err := t.handleRequest(
action, err := t.handleRequest(ctx,
Request{(*buffer)[:n], addr},
ResponseWriter{t.socket, addrPort},
)
@@ -218,7 +213,7 @@ func (w ResponseWriter) Write(b []byte) (int, error) {
}
// handleRequest parses and responds to a UDP Request.
func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, err error) {
func (t *udpFE) handleRequest(ctx context.Context, r Request, w ResponseWriter) (actionName string, err error) {
if len(r.Packet) < 16 {
// Malformed, no client packets are less than 16 bytes.
// We explicitly return nothing in case this is a DoS attempt.
@@ -265,9 +260,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e
return
}
var ctx context.Context
var resp *bittorrent.AnnounceResponse
ctx, resp, err = t.logic.HandleAnnounce(context.Background(), req)
ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{})
ctx, resp, err = t.logic.HandleAnnounce(ctx, req)
if err != nil {
WriteError(w, txID, err)
return
@@ -275,6 +270,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e
WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, r.IP.Is6())
ctx = bittorrent.RemapRouteParamsToBgContext(ctx)
go t.logic.AfterAnnounce(ctx, req, resp)
case scrapeActionID:
@@ -287,9 +283,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e
return
}
var ctx context.Context
var resp *bittorrent.ScrapeResponse
ctx, resp, err = t.logic.HandleScrape(context.Background(), req)
ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{})
ctx, resp, err = t.logic.HandleScrape(ctx, req)
if err != nil {
WriteError(w, txID, err)
return
@@ -297,6 +293,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e
WriteScrape(w, txID, resp)
ctx = bittorrent.RemapRouteParamsToBgContext(ctx)
go t.logic.AfterScrape(ctx, req, resp)
default:
+2 -3
View File
@@ -29,8 +29,7 @@ func TestStartStopRaceIssue437(t *testing.T) {
if err != nil {
t.Fatal(err)
}
errC := fe.Stop()
if errs := <-errC; len(errs) != 0 {
t.Fatal(errs)
if err = fe.Close(); err != nil {
t.Fatal(err)
}
}
+8 -12
View File
@@ -3,7 +3,6 @@ package udp
import (
"encoding/binary"
"errors"
"fmt"
"io"
"time"
@@ -12,17 +11,15 @@ import (
// WriteError writes the failure reason as a null-terminated string.
func WriteError(w io.Writer, txID []byte, err error) {
// If the client wasn't at fault, acknowledge it.
var clientErr bittorrent.ClientError
if !errors.Is(err, &clientErr) {
err = fmt.Errorf("internal error occurred: %w", err)
}
buf := reqRespBufferPool.Get()
defer reqRespBufferPool.Put(buf)
writeHeader(buf, txID, errorActionID)
_, _ = buf.WriteString(err.Error())
_, _ = buf.WriteRune('\000')
// If the client wasn't at fault, acknowledge it.
if !errors.As(err, new(bittorrent.ClientError)) {
buf.WriteString("internal error occurred: ")
}
buf.WriteString(err.Error())
buf.WriteByte('\000')
_, _ = w.Write(buf.Bytes())
}
@@ -50,7 +47,7 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse,
}
for _, peer := range peers {
_, _ = buf.Write(peer.Addr().AsSlice())
buf.Write(peer.Addr().AsSlice())
_ = binary.Write(buf, binary.BigEndian, peer.Port())
}
@@ -79,8 +76,7 @@ func WriteConnectionID(w io.Writer, txID, connID []byte) {
defer reqRespBufferPool.Put(buf)
writeHeader(buf, txID, connectActionID)
_, _ = buf.Write(connID)
buf.Write(connID)
_, _ = w.Write(buf.Bytes())
}
+16 -16
View File
@@ -46,17 +46,17 @@ func (h *swarmInteractionHook) HandleAnnounce(ctx context.Context, req *bittorre
return
}
var storeFn func(bittorrent.InfoHash, bittorrent.Peer) error
var storeFn func(context.Context, bittorrent.InfoHash, bittorrent.Peer) error
switch {
case req.Event == bittorrent.Stopped:
storeFn = func(hash bittorrent.InfoHash, peer bittorrent.Peer) error {
err = h.store.DeleteSeeder(hash, peer)
storeFn = func(ctx context.Context, hash bittorrent.InfoHash, peer bittorrent.Peer) error {
err = h.store.DeleteSeeder(ctx, hash, peer)
if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) {
return err
}
err = h.store.DeleteLeecher(hash, peer)
err = h.store.DeleteLeecher(ctx, hash, peer)
if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) {
return err
}
@@ -74,8 +74,8 @@ func (h *swarmInteractionHook) HandleAnnounce(ctx context.Context, req *bittorre
storeFn = h.store.PutLeecher
}
for _, p := range req.Peers() {
if err = storeFn(req.InfoHash, p); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len {
err = storeFn(req.InfoHash.TruncateV1(), p)
if err = storeFn(ctx, req.InfoHash, p); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len {
err = storeFn(ctx, req.InfoHash.TruncateV1(), p)
}
if err != nil {
break
@@ -102,10 +102,10 @@ type responseHook struct {
store storage.PeerStorage
}
func (h *responseHook) scrape(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
leechers, seeders, snatched = h.store.ScrapeSwarm(ih)
func (h *responseHook) scrape(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
leechers, seeders, snatched = h.store.ScrapeSwarm(ctx, ih)
if len(ih) == bittorrent.InfoHashV2Len {
l, s, n := h.store.ScrapeSwarm(ih.TruncateV1())
l, s, n := h.store.ScrapeSwarm(ctx, ih.TruncateV1())
leechers, seeders, snatched = leechers+l, seeders+s, snatched+n
}
return
@@ -117,9 +117,9 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou
}
// Add the Scrape data to the response.
resp.Incomplete, resp.Complete, _ = h.scrape(req.InfoHash)
resp.Incomplete, resp.Complete, _ = h.scrape(ctx, req.InfoHash)
err = h.appendPeers(req, resp)
err = h.appendPeers(ctx, req, resp)
return ctx, err
}
@@ -128,7 +128,7 @@ type fetchArgs struct {
v6 bool
}
func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (err error) {
func (h *responseHook) appendPeers(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (err error) {
seeding := req.Left == 0
max := int(req.NumWant)
peers := make([]bittorrent.Peer, 0, len(resp.IPv4Peers)+len(resp.IPv6Peers))
@@ -159,7 +159,7 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor
break
}
var storePeers []bittorrent.Peer
storePeers, err = h.store.AnnouncePeers(a.ih, seeding, max, a.v6)
storePeers, err = h.store.AnnouncePeers(ctx, a.ih, seeding, max, a.v6)
if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) {
return err
}
@@ -209,13 +209,13 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR
for _, infoHash := range req.InfoHashes {
scr := bittorrent.Scrape{InfoHash: infoHash}
scr.Incomplete, scr.Complete, scr.Snatches = h.scrape(infoHash)
scr.Incomplete, scr.Complete, scr.Snatches = h.scrape(ctx, infoHash)
resp.Files = append(resp.Files, scr)
}
return ctx, nil
}
func (h *responseHook) Ping(_ context.Context) error {
return h.store.Ping()
func (h *responseHook) Ping(ctx context.Context) error {
return h.store.Ping(ctx)
}
+3 -5
View File
@@ -19,7 +19,6 @@ import (
"github.com/sot-tech/mochi/middleware"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
)
@@ -110,13 +109,12 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err
return
}
func (h *hook) Stop() stop.Result {
func (h *hook) Close() error {
logger.Debug().Msg("attempting to shutdown JWT middleware")
c := make(stop.Channel)
if h.jwks != nil {
go h.jwks.EndBackground()
h.jwks.EndBackground()
}
return c.Result()
return nil
}
type verifiableClaims interface {
-23
View File
@@ -6,7 +6,6 @@ import (
"time"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
)
@@ -119,25 +118,3 @@ func (l *Logic) Ping(ctx context.Context) (err error) {
}
return
}
// Stop stops the Logic.
//
// This stops any hooks that implement stop.Stopper.
func (l *Logic) Stop() stop.Result {
stopGroup := stop.NewGroup()
for _, hook := range l.preHooks {
stoppable, ok := hook.(stop.Stopper)
if ok {
stopGroup.Add(stoppable)
}
}
for _, hook := range l.postHooks {
stoppable, ok := hook.(stop.Stopper)
if ok {
stopGroup.Add(stoppable)
}
}
return stopGroup.Stop()
}
@@ -3,6 +3,7 @@
package container
import (
"context"
"errors"
"sync"
@@ -41,7 +42,7 @@ func Register(n string, c Builder) {
// Container holds InfoHash and checks if value approved or not
type Container interface {
Approved(bittorrent.InfoHash) bool
Approved(context.Context, bittorrent.InfoHash) bool
}
// GetContainer creates Container by its name and provided confBytes
@@ -5,6 +5,7 @@
package directory
import (
"context"
"fmt"
"github.com/anacrolix/torrent/metainfo"
@@ -16,7 +17,6 @@ import (
"github.com/sot-tech/mochi/middleware/torrentapproval/container/list"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
)
@@ -48,11 +48,6 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
},
watcher: nil,
}
var w *dirwatch.Instance
if w, err = dirwatch.New(c.Path); err != nil {
return nil, fmt.Errorf("unable to initialize directory watch: %w", err)
}
d.watcher = w
if len(d.StorageCtx) == 0 {
logger.Warn().
Str("name", "StorageCtx").
@@ -61,6 +56,11 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
Msg("falling back to default configuration")
d.StorageCtx = container.DefaultStorageCtxName
}
var w *dirwatch.Instance
if w, err = dirwatch.New(c.Path); err != nil {
return nil, fmt.Errorf("unable to initialize directory watch: %w", err)
}
d.watcher = w
go func() {
for event := range d.watcher.Events {
var mi *metainfo.MetaInfo
@@ -85,28 +85,23 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
name = list.DUMMY
}
bName := []byte(name)
logger.Err(d.Storage.Put(d.StorageCtx,
storage.Entry{
Key: event.InfoHash.AsString(),
Value: bName,
}, storage.Entry{
Key: v2hash.RawString(),
Value: bName,
}, storage.Entry{
Key: v2hash.TruncateV1().RawString(),
Value: bName,
})).
logger.Err(d.Storage.Put(context.Background(), d.StorageCtx, storage.Entry{
Key: event.InfoHash.AsString(),
Value: bName,
}, storage.Entry{
Key: v2hash.RawString(),
Value: bName,
}, storage.Entry{
Key: v2hash.TruncateV1().RawString(),
Value: bName,
})).
Str("action", "add").
Str("file", event.TorrentFilePath).
Stringer("infoHash", event.InfoHash).
Stringer("infoHashV2", v2hash).
Msg("approval torrent watcher event")
case dirwatch.Removed:
logger.Err(d.Storage.Delete(c.StorageCtx,
event.InfoHash.AsString(),
v2hash.RawString(),
v2hash.TruncateV1().RawString(),
)).
logger.Err(d.Storage.Delete(context.Background(), c.StorageCtx, event.InfoHash.AsString(), v2hash.RawString(), v2hash.TruncateV1().RawString())).
Str("action", "delete").
Str("file", event.TorrentFilePath).
Stringer("infoHash", event.InfoHash).
@@ -128,10 +123,10 @@ type directory struct {
watcher *dirwatch.Instance
}
// Stop closes watching of torrent directory
func (d *directory) Stop() stop.Result {
st := make(stop.Channel)
d.watcher.Close()
st.Done()
return st.Result()
// Close closes watching of torrent directory
func (d *directory) Close() error {
if d.watcher != nil {
d.watcher.Close()
}
return nil
}
@@ -3,6 +3,7 @@
package list
import (
"context"
"fmt"
"github.com/sot-tech/mochi/bittorrent"
@@ -64,7 +65,7 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er
init = append(init, storage.Entry{Key: ih.TruncateV1().RawString(), Value: []byte(DUMMY)})
}
}
if err := l.Storage.Put(l.StorageCtx, init...); err != nil {
if err := l.Storage.Put(context.Background(), l.StorageCtx, init...); err != nil {
return nil, fmt.Errorf("unable to put initial data: %w", err)
}
}
@@ -84,11 +85,11 @@ type List struct {
// Approved checks if specified hash is approved or not.
// If List.Invert set to true and hash found in storage, function will return false,
// that means that hash is blacklisted.
func (l *List) Approved(hash bittorrent.InfoHash) (contains bool) {
func (l *List) Approved(ctx context.Context, hash bittorrent.InfoHash) (contains bool) {
var err error
if contains, err = l.Storage.Contains(l.StorageCtx, hash.RawString()); err == nil {
if contains, err = l.Storage.Contains(ctx, l.StorageCtx, hash.RawString()); err == nil {
if len(hash) == bittorrent.InfoHashV2Len {
if containsV2, errV2 := l.Storage.Contains(l.StorageCtx, hash.TruncateV1().RawString()); err == nil {
if containsV2, errV2 := l.Storage.Contains(ctx, l.StorageCtx, hash.TruncateV1().RawString()); err == nil {
contains = contains || containsV2
} else {
err = errV2
@@ -5,6 +5,7 @@ package torrentapproval
import (
"context"
"fmt"
"io"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/middleware"
@@ -17,7 +18,6 @@ import (
// import static list to enable appropriate support
_ "github.com/sot-tech/mochi/middleware/torrentapproval/container/list"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
)
@@ -74,7 +74,7 @@ type hook struct {
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) {
var err error
if !h.hashContainer.Approved(req.InfoHash) {
if !h.hashContainer.Approved(ctx, req.InfoHash) {
err = ErrTorrentUnapproved
}
@@ -86,9 +86,9 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _
return ctx, nil
}
func (h *hook) Stop() stop.Result {
if st, isOk := h.hashContainer.(stop.Stopper); isOk {
return st.Stop()
func (h *hook) Close() (err error) {
if cl, isOk := h.hashContainer.(io.Closer); isOk {
err = cl.Close()
}
return stop.AlreadyStopped
return err
}
+3 -9
View File
@@ -14,7 +14,6 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
)
const (
@@ -50,14 +49,9 @@ func AddressFamily(ip netip.Addr) string {
}
}
// Stop shuts down the server.
func (s *Server) Stop() stop.Result {
c := make(stop.Channel)
go func() {
c.Done(s.srv.Shutdown(context.Background()))
}()
return c.Result()
// Close shuts down the server.
func (s *Server) Close() error {
return s.srv.Shutdown(context.Background())
}
// NewServer creates a new instance of a Prometheus server that asynchronously
-125
View File
@@ -1,125 +0,0 @@
// Package stop implements a pattern for shutting down a group of processes.
package stop
import (
"sync"
)
// Channel is used to return zero or more errors asynchronously. Call Done()
// once to pass errors to the Channel.
type Channel chan []error
// Result is a receive-only version of Channel. Call Wait() once to receive any
// returned errors.
type Result <-chan []error
// Done adds zero or more errors to the Channel and closes it, indicating the
// caller has finished stopping. It should be called exactly once.
func (ch Channel) Done(errs ...error) {
if len(errs) > 0 && errs[0] != nil {
ch <- errs
}
close(ch)
}
// Result converts a Channel to a Result.
func (ch Channel) Result() <-chan []error {
return ch
}
// Wait blocks until Done() is called on the underlying Channel and returns any
// errors. It should be called exactly once.
func (r Result) Wait() []error {
return <-r
}
// AlreadyStopped is a closed error channel to be used by Funcs when
// an element was already stopped.
var AlreadyStopped Result
// AlreadyStoppedFunc is a Func that returns AlreadyStopped.
var AlreadyStoppedFunc = func() Result { return AlreadyStopped }
func init() {
closeMe := make(Channel)
close(closeMe)
AlreadyStopped = closeMe.Result()
}
// Stopper is an interface that allows a clean shutdown.
type Stopper interface {
// Stop returns a channel that indicates whether the stop was
// successful.
//
// The channel can either return one error or be closed.
// Closing the channel signals a clean shutdown.
// Stop() should return immediately and perform the actual shutdown in a
// separate goroutine.
Stop() Result
}
// Func is a function that can be used to provide a clean shutdown.
type Func func() Result
// Group is a collection of Stoppers that can be stopped all at once.
type Group struct {
stoppables []Func
sync.Mutex
}
// NewGroup allocates a new Group.
func NewGroup() *Group {
return &Group{
stoppables: make([]Func, 0),
}
}
// Add appends a Stopper to the Group.
func (cg *Group) Add(toAdd Stopper) {
cg.Lock()
defer cg.Unlock()
cg.stoppables = append(cg.stoppables, toAdd.Stop)
}
// AddFunc appends a Func to the Group.
func (cg *Group) AddFunc(toAddFunc Func) {
cg.Lock()
defer cg.Unlock()
cg.stoppables = append(cg.stoppables, toAddFunc)
}
// Stop stops all members of the Group.
//
// Stopping will be done in a concurrent fashion.
// The slice of errors returned contains all errors returned by stopping the
// members.
func (cg *Group) Stop() Result {
cg.Lock()
defer cg.Unlock()
whenDone := make(Channel)
waitChannels := make([]Result, 0, len(cg.stoppables))
for _, toStop := range cg.stoppables {
waitFor := toStop()
if waitFor == nil {
panic("received a nil chan from Stop")
}
waitChannels = append(waitChannels, waitFor)
}
go func() {
var errors []error
for _, waitForMe := range waitChannels {
childErrors := waitForMe.Wait()
if len(childErrors) > 0 {
errors = append(errors, childErrors...)
}
}
whenDone.Done(errors...)
}()
return whenDone.Result()
}
+24 -54
View File
@@ -12,14 +12,12 @@ package keydb
import (
"context"
"errors"
"time"
"github.com/go-redis/redis/v8"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/storage"
r "github.com/sot-tech/mochi/storage/redis"
)
@@ -79,10 +77,7 @@ func newStore(cfg r.Config) (*store, error) {
var st *store
if err == nil {
st = &store{
Connection: rs,
peerTTL: uint(cfg.PeerLifetime.Seconds()),
}
st = &store{Connection: rs, peerTTL: uint(cfg.PeerLifetime.Seconds())}
}
return st, err
@@ -93,27 +88,23 @@ type store struct {
peerTTL uint
}
func (s *store) setPeerTTL(infoHashKey, peerID string) error {
return s.Process(context.TODO(), redis.NewCmd(context.TODO(), expireMemberCmd, infoHashKey, peerID, s.peerTTL))
}
func (s *store) addPeer(infoHashKey, peerID string) (err error) {
func (s *store) addPeer(ctx context.Context, infoHashKey, peerID string) (err error) {
logger.Trace().
Str("infoHashKey", infoHashKey).
Str("peerID", peerID).
Msg("add peer")
if err = s.SAdd(context.TODO(), infoHashKey, peerID).Err(); err == nil {
err = s.setPeerTTL(infoHashKey, peerID)
if err = s.SAdd(ctx, infoHashKey, peerID).Err(); err == nil {
err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, infoHashKey, peerID, s.peerTTL))
}
return
}
func (s *store) delPeer(infoHashKey, peerID string) error {
func (s *store) delPeer(ctx context.Context, infoHashKey, peerID string) error {
logger.Trace().
Str("infoHashKey", infoHashKey).
Str("peerID", peerID).
Msg("del peer")
deleted, err := s.SRem(context.TODO(), infoHashKey, peerID).Uint64()
deleted, err := s.SRem(ctx, infoHashKey, peerID).Uint64()
err = r.AsNil(err)
if err == nil && deleted == 0 {
err = storage.ErrResourceDoesNotExist
@@ -122,23 +113,23 @@ func (s *store) delPeer(infoHashKey, peerID string) error {
return err
}
func (s *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
}
func (s *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString())
}
func (s *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
}
func (s *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString())
}
func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) {
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) {
logger.Trace().
Stringer("infoHash", ih).
Object("peer", peer).
@@ -147,21 +138,21 @@ func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) (e
ihSeederKey := r.InfoHashKey(infoHash, true, peer.Addr().Is6())
ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6())
var moved bool
if moved, err = s.SMove(context.TODO(), ihLeecherKey, ihSeederKey, peerID).Result(); err == nil {
if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil {
if moved {
err = s.setPeerTTL(ihSeederKey, peerID)
err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL))
} else {
err = s.addPeer(ihSeederKey, peerID)
err = s.addPeer(ctx, ihSeederKey, peerID)
}
if err == nil {
err = s.HIncrBy(context.TODO(), r.CountDownloadsKey, infoHash, 1).Err()
err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err()
}
}
return err
}
// AnnouncePeers is the same function as redis.AnnouncePeers
func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) {
func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) {
logger.Trace().
Stringer("infoHash", ih).
Bool("forSeeder", forSeeder).
@@ -169,36 +160,15 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in
Bool("v6", v6).
Msg("announce peers")
return s.GetPeers(ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd {
return s.SRandMemberN(context.TODO(), infoHashKey, int64(maxCount))
return s.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
return s.SRandMemberN(ctx, infoHashKey, int64(maxCount))
})
}
// ScrapeSwarm is the same function as redis.ScrapeSwarm except `SCard` call instead of `HLen`
func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (uint32, uint32, uint32) {
func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32, uint32, uint32) {
logger.Trace().
Stringer("infoHash", ih).
Msg("scrape swarm")
return s.ScrapeIH(ih, s.SCard)
}
func (*store) GCAware() bool {
return false
}
func (*store) ScheduleGC(_, _ time.Duration) {}
func (*store) StatisticsAware() bool {
return false
}
func (*store) ScheduleStatisticsCollection(_ time.Duration) {}
func (s *store) Stop() stop.Result {
c := make(stop.Channel)
if s.UniversalClient != nil {
c.Done(s.UniversalClient.Close())
s.UniversalClient = nil
}
return c.Result()
return s.ScrapeIH(ctx, ih, s.SCard)
}
+23 -35
View File
@@ -3,6 +3,7 @@
package memory
import (
"context"
"encoding/binary"
"math"
"runtime"
@@ -13,7 +14,6 @@ import (
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/pkg/timecache"
"github.com/sot-tech/mochi/storage"
)
@@ -95,8 +95,9 @@ type peerStore struct {
cfg Config
shards []*peerShard
closed chan struct{}
wg sync.WaitGroup
closed chan struct{}
wg sync.WaitGroup
onceCloser sync.Once
}
var _ storage.PeerStorage = &peerStore{}
@@ -174,7 +175,7 @@ func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 {
return idx
}
func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
func (ps *peerStore) PutSeeder(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -207,7 +208,7 @@ func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error
return nil
}
func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error {
func (ps *peerStore) DeleteSeeder(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -240,7 +241,7 @@ func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) err
return nil
}
func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
func (ps *peerStore) PutLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -273,7 +274,7 @@ func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error
return nil
}
func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
func (ps *peerStore) DeleteLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -306,7 +307,7 @@ func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) er
return nil
}
func (ps *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error {
func (ps *peerStore) GraduateLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -372,7 +373,7 @@ func (ps *peerStore) getPeers(shard *peerShard, ih bittorrent.InfoHash, maxCount
return
}
func (ps *peerStore) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) {
func (ps *peerStore) AnnouncePeers(_ context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -401,7 +402,7 @@ func (ps *peerStore) countPeers(ih bittorrent.InfoHash, v6 bool) (leechers, seed
return
}
func (ps *peerStore) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
func (ps *peerStore) ScrapeSwarm(_ context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
select {
case <-ps.closed:
panic("attempted to interact with stopped memory store")
@@ -427,7 +428,7 @@ type dataStore struct {
sync.Map
}
func (ds *dataStore) Put(ctx string, values ...storage.Entry) error {
func (ds *dataStore) Put(_ context.Context, ctx string, values ...storage.Entry) error {
if len(values) > 0 {
c, _ := ds.LoadOrStore(ctx, new(sync.Map))
m := c.(*sync.Map)
@@ -438,7 +439,7 @@ func (ds *dataStore) Put(ctx string, values ...storage.Entry) error {
return nil
}
func (ds *dataStore) Contains(ctx string, key string) (bool, error) {
func (ds *dataStore) Contains(_ context.Context, ctx string, key string) (bool, error) {
var exist bool
if m, found := ds.Map.Load(ctx); found {
_, exist = m.(*sync.Map).Load(key)
@@ -446,7 +447,7 @@ func (ds *dataStore) Contains(ctx string, key string) (bool, error) {
return exist, nil
}
func (ds *dataStore) Load(ctx string, key string) (out []byte, _ error) {
func (ds *dataStore) Load(_ context.Context, ctx string, key string) (out []byte, _ error) {
if m, found := ds.Map.Load(ctx); found {
if v, _ := m.(*sync.Map).Load(key); v != nil {
out = v.([]byte)
@@ -455,7 +456,7 @@ func (ds *dataStore) Load(ctx string, key string) (out []byte, _ error) {
return
}
func (ds *dataStore) Delete(ctx string, keys ...string) error {
func (ds *dataStore) Delete(_ context.Context, ctx string, keys ...string) error {
if len(keys) > 0 {
if m, found := ds.Map.Load(ctx); found {
m := m.(*sync.Map)
@@ -467,17 +468,9 @@ func (ds *dataStore) Delete(ctx string, keys ...string) error {
return nil
}
func (*dataStore) Preservable() bool {
return false
}
func (*dataStore) Preservable() bool { return false }
func (*peerStore) GCAware() bool {
return true
}
func (*peerStore) StatisticsAware() bool {
return true
}
func (ds *dataStore) Close() error { return nil }
// GC deletes all Peers from the PeerStorage which are older than the
// cutoff time.
@@ -537,16 +530,13 @@ func (ps *peerStore) gc(cutoff time.Time) {
}
}
func (*peerStore) Ping() error {
func (*peerStore) Ping(context.Context) error {
return nil
}
func (ps *peerStore) Stop() stop.Result {
c := make(stop.Channel)
go func() {
if ps.closed != nil {
close(ps.closed)
}
func (ps *peerStore) Close() error {
ps.onceCloser.Do(func() {
close(ps.closed)
ps.wg.Wait()
// Explicitly deallocate our storage.
@@ -555,9 +545,7 @@ func (ps *peerStore) Stop() stop.Result {
shards[i] = &peerShard{swarms: make(map[bittorrent.InfoHash]swarm)}
}
ps.shards = shards
})
c.Done()
}()
return c.Result()
return nil
}
+62 -67
View File
@@ -20,7 +20,6 @@ import (
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/pkg/timecache"
"github.com/sot-tech/mochi/storage"
)
@@ -76,7 +75,13 @@ func newStore(cfg Config) (storage.PeerStorage, error) {
return nil, err
}
return &store{Config: cfg, Pool: con, wg: sync.WaitGroup{}, closed: make(chan any)}, nil
return &store{
Config: cfg,
Pool: con,
wg: sync.WaitGroup{},
closed: make(chan any),
onceCloser: sync.Once{},
}, nil
}
type peerQueryConf struct {
@@ -215,14 +220,15 @@ func (cfg Config) Validate() (Config, error) {
type store struct {
Config
*pgxpool.Pool
wg sync.WaitGroup
closed chan any
wg sync.WaitGroup
closed chan any
onceCloser sync.Once
}
func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) {
var tx pgx.Tx
if tx, err = s.Begin(ctx); err == nil {
if err = tx.SendBatch(context.TODO(), batch).Close(); err == nil {
if err = tx.SendBatch(ctx, batch).Close(); err == nil {
err = tx.Commit(ctx)
} else {
if txErr := tx.Rollback(ctx); txErr != nil {
@@ -233,25 +239,25 @@ func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) {
return
}
func (s *store) Put(ctx string, values ...storage.Entry) (err error) {
func (s *store) Put(ctx context.Context, storeCtx string, values ...storage.Entry) (err error) {
switch len(values) {
case 0:
// ignore
case 1:
_, err = s.Exec(context.TODO(), s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(values[0].Key), pValue: values[0].Value})
_, err = s.Exec(ctx, s.Data.AddQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(values[0].Key), pValue: values[0].Value})
default:
var batch pgx.Batch
for _, v := range values {
batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(v.Key), pValue: v.Value})
batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(v.Key), pValue: v.Value})
}
err = s.txBatch(context.TODO(), &batch)
err = s.txBatch(ctx, &batch)
}
return
}
func (s *store) Contains(ctx string, key string) (contains bool, err error) {
func (s *store) Contains(ctx context.Context, storeCtx string, key string) (contains bool, err error) {
var rows pgx.Rows
if rows, err = s.Query(context.TODO(), s.Data.GetQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(key)}); err == nil {
if rows, err = s.Query(ctx, s.Data.GetQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(key)}); err == nil {
defer rows.Close()
contains = rows.Next()
err = rows.Err()
@@ -259,20 +265,20 @@ func (s *store) Contains(ctx string, key string) (contains bool, err error) {
return
}
func (s *store) Load(ctx string, key string) (out []byte, err error) {
if err = s.QueryRow(context.TODO(), s.Data.GetQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(key)}).Scan(&out); errors.Is(err, pgx.ErrNoRows) {
func (s *store) Load(ctx context.Context, storeCtx string, key string) (out []byte, err error) {
if err = s.QueryRow(ctx, s.Data.GetQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(key)}).Scan(&out); errors.Is(err, pgx.ErrNoRows) {
err = nil
}
return
}
func (s *store) Delete(ctx string, keys ...string) (err error) {
func (s *store) Delete(ctx context.Context, storeCtx string, keys ...string) (err error) {
if len(keys) > 0 {
baKeys := make([][]byte, len(keys))
for i, k := range keys {
baKeys[i] = []byte(k)
}
_, err = s.Exec(context.TODO(), s.Data.DelQuery, pgx.NamedArgs{pCtx: ctx, pKey: baKeys})
_, err = s.Exec(ctx, s.Data.DelQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: baKeys})
}
return
}
@@ -281,11 +287,10 @@ func (s *store) Preservable() bool {
return true
}
func (s *store) GCAware() bool {
return len(s.GCQuery) > 0
}
func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) {
if len(s.GCQuery) == 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
@@ -311,11 +316,10 @@ func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) {
}()
}
func (s *store) StatisticsAware() bool {
return len(s.InfoHashCountQuery) > 0
}
func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) {
if len(s.InfoHashCountQuery) == 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
@@ -328,7 +332,7 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) {
case <-t.C:
if metrics.Enabled() {
before := time.Now()
sc, lc := s.countPeers(nil)
sc, lc := s.countPeers(context.Background(), nil)
var hc int
if err := s.QueryRow(context.Background(), s.InfoHashCountQuery).Scan(&hc); err != nil && !errors.Is(err, pgx.ErrNoRows) {
logger.Error().Err(err).Msg("error occurred while get info hash count")
@@ -344,7 +348,7 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) {
}()
}
func (s *store) putPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
logger.Trace().
Stringer("infoHash", ih).
Object("peer", peer).
@@ -357,20 +361,18 @@ func (s *store) putPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo
pPort: peer.Port(),
pSeeder: seeder,
pV6: peer.Addr().Is6(),
pCreated: timecache.Now(),
}
if s.GCAware() {
args[pCreated] = timecache.Now()
}
_, err = s.Exec(context.TODO(), s.Peer.AddQuery, args)
_, err = s.Exec(ctx, s.Peer.AddQuery, args)
return
}
func (s *store) delPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) {
logger.Trace().
Stringer("infoHash", ih).
Object("peer", peer).
Msg("del peer")
_, err = s.Exec(context.TODO(), s.Peer.DelQuery, pgx.NamedArgs{
_, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{
pInfoHash: []byte(ih),
pPeerID: peer.ID[:],
pAddress: net.IP(peer.Addr().AsSlice()),
@@ -380,23 +382,23 @@ func (s *store) delPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo
return
}
func (s *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ih, peer, true)
func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ctx, ih, peer, true)
}
func (s *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ih, peer, true)
func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, ih, peer, true)
}
func (s *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ih, peer, false)
func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.putPeer(ctx, ih, peer, false)
}
func (s *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ih, peer, false)
func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return s.delPeer(ctx, ih, peer, false)
}
func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
logger.Trace().
Stringer("infoHash", ih).
Object("peer", peer).
@@ -410,12 +412,12 @@ func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) er
pPort: peer.Port(),
})
batch.Queue(s.Downloads.IncrementQuery, pgx.NamedArgs{pInfoHash: ihb})
return s.txBatch(context.TODO(), &batch)
return s.txBatch(ctx, &batch)
}
func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
func (s *store) getPeers(ctx context.Context, ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) {
var rows pgx.Rows
if rows, err = s.Query(context.TODO(), s.Announce.Query, pgx.NamedArgs{
if rows, err = s.Query(ctx, s.Announce.Query, pgx.NamedArgs{
pInfoHash: []byte(ih),
pSeeder: seeders,
pV6: isV6,
@@ -484,7 +486,7 @@ func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV
return
}
func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) {
func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) {
logger.Trace().
Stringer("infoHash", ih).
Bool("forSeeder", forSeeder).
@@ -492,11 +494,11 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in
Bool("v6", v6).
Msg("announce peers")
if forSeeder {
peers, err = s.getPeers(ih, false, numWant, v6)
peers, err = s.getPeers(ctx, ih, false, numWant, v6)
} else {
if peers, err = s.getPeers(ih, true, numWant, v6); err == nil {
if peers, err = s.getPeers(ctx, ih, true, numWant, v6); err == nil {
var addPeers []bittorrent.Peer
addPeers, err = s.getPeers(ih, false, numWant-len(peers), v6)
addPeers, err = s.getPeers(ctx, ih, false, numWant-len(peers), v6)
peers = append(peers, addPeers...)
}
}
@@ -513,13 +515,13 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in
return
}
func (s *store) countPeers(ih []byte) (seeders uint32, leechers uint32) {
func (s *store) countPeers(ctx context.Context, ih []byte) (seeders uint32, leechers uint32) {
var rows pgx.Rows
var err error
if len(ih) == 0 {
rows, err = s.Query(context.TODO(), s.Peer.CountQuery)
rows, err = s.Query(ctx, s.Peer.CountQuery)
} else {
rows, err = s.Query(context.TODO(), s.Peer.CountQuery+" "+s.Peer.ByInfoHashClause, pgx.NamedArgs{pInfoHash: ih})
rows, err = s.Query(ctx, s.Peer.CountQuery+" "+s.Peer.ByInfoHashClause, pgx.NamedArgs{pInfoHash: ih})
}
if err == nil {
defer rows.Close()
@@ -556,14 +558,14 @@ func (s *store) countPeers(ih []byte) (seeders uint32, leechers uint32) {
return
}
func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) {
logger.Trace().
Stringer("infoHash", ih).
Msg("scrape swarm")
ihb := []byte(ih)
seeders, leechers = s.countPeers(ihb)
seeders, leechers = s.countPeers(ctx, ihb)
if len(s.Downloads.GetQuery) > 0 {
if err := s.QueryRow(context.TODO(), s.Downloads.GetQuery, pgx.NamedArgs{pInfoHash: ihb}).Scan(&snatched); err != nil && !errors.Is(err, pgx.ErrNoRows) {
if err := s.QueryRow(ctx, s.Downloads.GetQuery, pgx.NamedArgs{pInfoHash: ihb}).Scan(&snatched); err != nil && !errors.Is(err, pgx.ErrNoRows) {
logger.Error().Stringer("infoHash", ih).Err(err).Msg("error occurred while get info downloads count")
}
}
@@ -571,24 +573,17 @@ func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders ui
return
}
func (s *store) Ping() error {
_, err := s.Exec(context.TODO(), s.PingQuery)
func (s *store) Ping(ctx context.Context) error {
_, err := s.Exec(ctx, s.PingQuery)
return err
}
func (s *store) Stop() stop.Result {
c := make(stop.Channel)
func (s *store) Close() error {
go func() {
if s.closed != nil {
close(s.closed)
}
close(s.closed)
s.wg.Wait()
if s.Pool != nil {
logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.")
s.Close()
s.Pool = nil
}
c.Done()
logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.")
s.Pool.Close()
}()
return c.Result()
return nil
}
+64 -86
View File
@@ -36,7 +36,6 @@ import (
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/metrics"
"github.com/sot-tech/mochi/pkg/stop"
"github.com/sot-tech/mochi/pkg/timecache"
"github.com/sot-tech/mochi/storage"
)
@@ -100,11 +99,7 @@ func newStore(cfg Config) (*store, error) {
return nil, err
}
return &store{
Connection: rs,
closed: make(chan any),
wg: sync.WaitGroup{},
}, nil
return &store{Connection: rs, closed: make(chan any)}, nil
}
// Config holds the configuration of a redis PeerStorage.
@@ -287,8 +282,9 @@ type Connection struct {
type store struct {
Connection
closed chan any
wg sync.WaitGroup
closed chan any
wg sync.WaitGroup
onceCloser sync.Once
}
func (ps *store) count(key string, getLength bool) (n uint64) {
@@ -309,8 +305,8 @@ func (ps *store) getClock() int64 {
return timecache.NowUnixNano()
}
func (ps *store) tx(txf func(tx redis.Pipeliner) error) (err error) {
if pipe, txErr := ps.TxPipelined(context.TODO(), txf); txErr == nil {
func (ps *store) tx(ctx context.Context, txf func(tx redis.Pipeliner) error) (err error) {
if pipe, txErr := ps.TxPipelined(ctx, txf); txErr == nil {
errs := make([]string, 0)
for _, c := range pipe {
if err := c.Err(); err != nil {
@@ -358,58 +354,58 @@ func InfoHashKey(infoHash string, seeder, v6 bool) (infoHashKey string) {
return
}
func (ps *store) putPeer(infoHashKey, peerCountKey, peerID string) error {
func (ps *store) putPeer(ctx context.Context, infoHashKey, peerCountKey, peerID string) error {
logger.Trace().
Str("infoHashKey", infoHashKey).
Str("peerID", peerID).
Msg("put peer")
return ps.tx(func(tx redis.Pipeliner) (err error) {
if err = tx.HSet(context.TODO(), infoHashKey, peerID, ps.getClock()).Err(); err != nil {
return ps.tx(ctx, func(tx redis.Pipeliner) (err error) {
if err = tx.HSet(ctx, infoHashKey, peerID, ps.getClock()).Err(); err != nil {
return
}
if err = tx.Incr(context.TODO(), peerCountKey).Err(); err != nil {
if err = tx.Incr(ctx, peerCountKey).Err(); err != nil {
return
}
err = tx.SAdd(context.TODO(), IHKey, infoHashKey).Err()
err = tx.SAdd(ctx, IHKey, infoHashKey).Err()
return
})
}
func (ps *store) delPeer(infoHashKey, peerCountKey, peerID string) error {
func (ps *store) delPeer(ctx context.Context, infoHashKey, peerCountKey, peerID string) error {
logger.Trace().
Str("infoHashKey", infoHashKey).
Str("peerID", peerID).
Msg("del peer")
deleted, err := ps.HDel(context.TODO(), infoHashKey, peerID).Uint64()
deleted, err := ps.HDel(ctx, infoHashKey, peerID).Uint64()
err = AsNil(err)
if err == nil {
if deleted == 0 {
err = storage.ErrResourceDoesNotExist
} else {
err = ps.Decr(context.TODO(), peerCountKey).Err()
err = ps.Decr(ctx, peerCountKey).Err()
}
}
return err
}
func (ps *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
}
func (ps *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
func (ps *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString())
}
func (ps *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
func (ps *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
}
func (ps *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
func (ps *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString())
}
func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error {
func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error {
logger.Trace().
Stringer("infoHash", ih).
Object("peer", peer).
@@ -418,25 +414,25 @@ func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) e
infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6()
ihSeederKey, ihLeecherKey := InfoHashKey(infoHash, true, isV6), InfoHashKey(infoHash, false, isV6)
return ps.tx(func(tx redis.Pipeliner) error {
deleted, err := tx.HDel(context.TODO(), ihLeecherKey, peerID).Uint64()
return ps.tx(ctx, func(tx redis.Pipeliner) error {
deleted, err := tx.HDel(ctx, ihLeecherKey, peerID).Uint64()
err = AsNil(err)
if err == nil {
if deleted > 0 {
err = tx.Decr(context.TODO(), CountLeecherKey).Err()
err = tx.Decr(ctx, CountLeecherKey).Err()
}
}
if err == nil {
err = tx.HSet(context.TODO(), ihSeederKey, peerID, ps.getClock()).Err()
err = tx.HSet(ctx, ihSeederKey, peerID, ps.getClock()).Err()
}
if err == nil {
err = tx.Incr(context.TODO(), CountSeederKey).Err()
err = tx.Incr(ctx, CountSeederKey).Err()
}
if err == nil {
err = tx.SAdd(context.TODO(), IHKey, ihSeederKey).Err()
err = tx.SAdd(ctx, IHKey, ihSeederKey).Err()
}
if err == nil {
err = tx.HIncrBy(context.TODO(), CountDownloadsKey, infoHash, 1).Err()
err = tx.HIncrBy(ctx, CountDownloadsKey, infoHash, 1).Err()
}
return err
})
@@ -457,7 +453,7 @@ func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers [
return
}
type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd
type getPeersFn func(string, int) *redis.StringSliceCmd
// GetPeers retrieves peers for provided info hash by calling membersFn and
// converts result to bittorrent.Peer array.
@@ -477,7 +473,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount
for _, infoHashKey := range infoHashKeys {
var peers []bittorrent.Peer
peers, err = ps.parsePeersList(membersFn(context.TODO(), infoHashKey, maxCount))
peers, err = ps.parsePeersList(membersFn(infoHashKey, maxCount))
maxCount -= len(peers)
out = append(out, peers...)
if err != nil || maxCount <= 0 {
@@ -497,7 +493,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount
return
}
func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) {
func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) {
logger.Trace().
Stringer("infoHash", ih).
Bool("forSeeder", forSeeder).
@@ -505,15 +501,15 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant i
Bool("v6", v6).
Msg("announce peers")
return ps.GetPeers(ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd {
return ps.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd {
return ps.HRandField(ctx, infoHashKey, maxCount, false)
})
}
type getPeerCountFn func(context.Context, string) *redis.IntCmd
func (ps *Connection) countPeers(infoHashKey string, countFn getPeerCountFn) uint32 {
count, err := countFn(context.TODO(), infoHashKey).Result()
func (ps *Connection) countPeers(ctx context.Context, infoHashKey string, countFn getPeerCountFn) uint32 {
count, err := countFn(ctx, infoHashKey).Result()
err = AsNil(err)
if err != nil {
logger.Error().Err(err).Str("infoHashKey", infoHashKey).Msg("key size calculation failure")
@@ -522,14 +518,14 @@ func (ps *Connection) countPeers(infoHashKey string, countFn getPeerCountFn) uin
}
// ScrapeIH calls provided countFn and returns seeders, leechers and downloads count for specified info hash
func (ps *Connection) ScrapeIH(ih bittorrent.InfoHash, countFn getPeerCountFn) (leechersCount, seedersCount, downloadsCount uint32) {
func (ps *Connection) ScrapeIH(ctx context.Context, ih bittorrent.InfoHash, countFn getPeerCountFn) (leechersCount, seedersCount, downloadsCount uint32) {
infoHash := ih.RawString()
leechersCount = ps.countPeers(InfoHashKey(infoHash, false, false), countFn) +
ps.countPeers(InfoHashKey(infoHash, false, true), countFn)
seedersCount = ps.countPeers(InfoHashKey(infoHash, true, false), countFn) +
ps.countPeers(InfoHashKey(infoHash, true, true), countFn)
d, err := ps.HGet(context.TODO(), CountDownloadsKey, infoHash).Uint64()
leechersCount = ps.countPeers(ctx, InfoHashKey(infoHash, false, false), countFn) +
ps.countPeers(ctx, InfoHashKey(infoHash, false, true), countFn)
seedersCount = ps.countPeers(ctx, InfoHashKey(infoHash, true, false), countFn) +
ps.countPeers(ctx, InfoHashKey(infoHash, true, true), countFn)
d, err := ps.HGet(ctx, CountDownloadsKey, infoHash).Uint64()
if err = AsNil(err); err != nil {
logger.Error().Err(err).Str("infoHash", infoHash).Msg("downloads count calculation failure")
}
@@ -538,31 +534,31 @@ func (ps *Connection) ScrapeIH(ih bittorrent.InfoHash, countFn getPeerCountFn) (
return
}
func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash) (uint32, uint32, uint32) {
func (ps *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32, uint32, uint32) {
logger.Trace().
Stringer("infoHash", ih).
Msg("scrape swarm")
return ps.ScrapeIH(ih, ps.HLen)
return ps.ScrapeIH(ctx, ih, ps.HLen)
}
const argNumErrorMsg = "ERR wrong number of arguments"
// Put - storage.DataStorage implementation
func (ps *Connection) Put(ctx string, values ...storage.Entry) (err error) {
func (ps *Connection) Put(ctx context.Context, storeCtx string, values ...storage.Entry) (err error) {
if l := len(values); l > 0 {
if l == 1 {
err = ps.HSet(context.TODO(), PrefixKey+ctx, values[0].Key, values[0].Value).Err()
err = ps.HSet(ctx, PrefixKey+storeCtx, values[0].Key, values[0].Value).Err()
} else {
args := make([]any, 0, l*2)
for _, p := range values {
args = append(args, p.Key, p.Value)
}
err = ps.HSet(context.TODO(), PrefixKey+ctx, args...).Err()
err = ps.HSet(ctx, PrefixKey+storeCtx, args...).Err()
if err != nil {
if strings.Contains(err.Error(), argNumErrorMsg) {
logger.Warn().Msg("This Redis version/implementation does not support variadic arguments for HSET")
for _, p := range values {
if err = ps.HSet(context.TODO(), PrefixKey+ctx, p.Key, p.Value).Err(); err != nil {
if err = ps.HSet(ctx, PrefixKey+storeCtx, p.Key, p.Value).Err(); err != nil {
break
}
}
@@ -574,14 +570,14 @@ func (ps *Connection) Put(ctx string, values ...storage.Entry) (err error) {
}
// Contains - storage.DataStorage implementation
func (ps *Connection) Contains(ctx string, key string) (bool, error) {
exist, err := ps.HExists(context.TODO(), PrefixKey+ctx, key).Result()
func (ps *Connection) Contains(ctx context.Context, storeCtx string, key string) (bool, error) {
exist, err := ps.HExists(ctx, PrefixKey+storeCtx, key).Result()
return exist, AsNil(err)
}
// Load - storage.DataStorage implementation
func (ps *Connection) Load(ctx string, key string) (v []byte, err error) {
v, err = ps.HGet(context.TODO(), PrefixKey+ctx, key).Bytes()
func (ps *Connection) Load(ctx context.Context, storeCtx string, key string) (v []byte, err error) {
v, err = ps.HGet(ctx, PrefixKey+storeCtx, key).Bytes()
if err != nil && errors.Is(err, redis.Nil) {
v, err = nil, nil
}
@@ -589,14 +585,14 @@ func (ps *Connection) Load(ctx string, key string) (v []byte, err error) {
}
// Delete - storage.DataStorage implementation
func (ps *Connection) Delete(ctx string, keys ...string) (err error) {
func (ps *Connection) Delete(ctx context.Context, storeCtx string, keys ...string) (err error) {
if len(keys) > 0 {
err = AsNil(ps.HDel(context.TODO(), PrefixKey+ctx, keys...).Err())
err = AsNil(ps.HDel(ctx, PrefixKey+storeCtx, keys...).Err())
if err != nil {
if strings.Contains(err.Error(), argNumErrorMsg) {
logger.Warn().Msg("This Redis version/implementation does not support variadic arguments for HDEL")
for _, k := range keys {
if err = AsNil(ps.HDel(context.TODO(), PrefixKey+ctx, k).Err()); err != nil {
if err = AsNil(ps.HDel(ctx, PrefixKey+storeCtx, k).Err()); err != nil {
break
}
}
@@ -611,17 +607,9 @@ func (*Connection) Preservable() bool {
return true
}
func (*store) GCAware() bool {
return true
}
func (*store) StatisticsAware() bool {
return true
}
// Ping sends `PING` request to Redis server
func (ps *Connection) Ping() error {
return ps.UniversalClient.Ping(context.TODO()).Err()
func (ps *Connection) Ping(ctx context.Context) error {
return ps.UniversalClient.Ping(ctx).Err()
}
// GC deletes all Peers from the PeerStorage which are older than the
@@ -747,7 +735,6 @@ func (ps *store) gc(cutoff time.Time) {
if err == nil && infoHashCount == 0 {
// Empty hashes are not shown among existing keys,
// in other words, it's removed automatically after `HDEL` the last field.
// _, err := ps.Del(context.TODO(), infoHashKey)
err = AsNil(ps.SRem(context.Background(), IHKey, infoHashKey).Err())
}
return err
@@ -770,21 +757,12 @@ func (ps *store) gc(cutoff time.Time) {
}
}
func (ps *store) Stop() stop.Result {
c := make(stop.Channel)
go func() {
if ps.closed != nil {
close(ps.closed)
}
func (ps *store) Close() (err error) {
ps.onceCloser.Do(func() {
close(ps.closed)
ps.wg.Wait()
var err error
if ps.UniversalClient != nil {
logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey)
err = ps.UniversalClient.Close()
ps.UniversalClient = nil
}
c.Done(err)
}()
return c.Result()
logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey)
err = ps.UniversalClient.Close()
})
return
}
+28 -30
View File
@@ -3,14 +3,15 @@
package storage
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/sot-tech/mochi/bittorrent"
"github.com/sot-tech/mochi/pkg/conf"
"github.com/sot-tech/mochi/pkg/log"
"github.com/sot-tech/mochi/pkg/stop"
)
const (
@@ -92,19 +93,20 @@ var ErrResourceDoesNotExist = bittorrent.ClientError("resource does not exist")
// DataStorage is the interface, used for implementing store for arbitrary data
type DataStorage interface {
io.Closer
// Put used to place arbitrary k-v data with specified context
// into storage. ctx parameter used to group data
// into storage. storeCtx parameter used to group data
// (i.e. data only for specific middleware module: hash key, table name etc...)
Put(ctx string, values ...Entry) error
Put(ctx context.Context, storeCtx string, values ...Entry) error
// Contains checks if any data in specified context exist
Contains(ctx string, key string) (bool, error)
Contains(ctx context.Context, storeCtx string, key string) (bool, error)
// Load used to get arbitrary data in specified context by its key
Load(ctx string, key string) ([]byte, error)
Load(ctx context.Context, storeCtx string, key string) ([]byte, error)
// Delete used to delete arbitrary data in specified context by its keys
Delete(ctx string, keys ...string) error
Delete(ctx context.Context, storeCtx string, keys ...string) error
// Preservable indicates, that this storage can store data permanently,
// in other words, is NOT in-memory storage, which data will be lost after restart
@@ -130,33 +132,33 @@ type PeerStorage interface {
DataStorage
// PutSeeder adds a Seeder to the Swarm identified by the provided
// InfoHash.
PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error
PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error
// DeleteSeeder removes a Seeder from the Swarm identified by the
// provided InfoHash.
//
// If the Swarm or Peer does not exist, this function returns
// ErrResourceDoesNotExist.
DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error
DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error
// PutLeecher adds a Leecher to the Swarm identified by the provided
// InfoHash.
// If the Swarm does not exist already, it is created.
PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error
PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error
// DeleteLeecher removes a Leecher from the Swarm identified by the
// provided InfoHash.
//
// If the Swarm or Peer does not exist, this function returns
// ErrResourceDoesNotExist.
DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error
DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error
// GraduateLeecher promotes a Leecher to a Seeder in the Swarm
// identified by the provided InfoHash.
//
// If the given Peer is not present as a Leecher or the swarm does not exist
// already, the Peer is added as a Seeder and no error is returned.
GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error
GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error
// AnnouncePeers is a best effort attempt to return Peers from the Swarm
// identified by the provided InfoHash.
@@ -173,7 +175,7 @@ type PeerStorage interface {
// leechers
//
// Returns ErrResourceDoesNotExist if the provided InfoHash is not tracked.
AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error)
AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error)
// ScrapeSwarm returns information required to answer a Scrape request
// about a Swarm identified by the given InfoHash.
@@ -183,32 +185,28 @@ type PeerStorage interface {
// filling the Snatches field is optional.
//
// If the Swarm does not exist, an empty Scrape and no error is returned.
ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32)
ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32)
// Ping used for checks if storage is alive
// (connection could be established, enough space etc.)
Ping() error
// GCAware marks that this storage supports periodic
// peers collection
GCAware() bool
Ping(ctx context.Context) error
}
// GarbageCollector marks that this storage supports periodic
// stale peers collection
type GarbageCollector interface {
// ScheduleGC used to delete stale data, such as timed out seeders/leechers.
// Note: implementation must create subroutine by itself
ScheduleGC(gcInterval, peerLifeTime time.Duration)
}
// StatisticsAware marks that this storage supports periodic
// statistics collection
StatisticsAware() bool
// StatisticsCollector marks that this storage supports periodic
// statistics collection
type StatisticsCollector interface {
// ScheduleStatisticsCollection used to receive statistics information about hashes,
// seeders and leechers count.
// Note: implementation must create subroutine by itself
ScheduleStatisticsCollection(reportInterval time.Duration)
// Stopper is an interface that expects a Stop method to stop the PeerStorage.
// For more details see the documentation in the stop package.
stop.Stopper
}
// RegisterDriver makes a Driver available by the provided name.
@@ -255,27 +253,27 @@ func NewStorage(cfg conf.NamedMapConfig) (ps PeerStorage, err error) {
return
}
if gc := ps.GCAware(); gc {
if gc, isOk := ps.(GarbageCollector); isOk {
gcInterval, peerTTL := c.sanitizeGCConfig()
logger.Info().
Str("name", cfg.Name).
Dur("gcInterval", gcInterval).
Dur("peerTTL", peerTTL).
Msg("scheduling GC")
ps.ScheduleGC(gcInterval, peerTTL)
gc.ScheduleGC(gcInterval, peerTTL)
} else {
logger.Debug().
Str("name", cfg.Name).
Msg("storage does not support GC")
}
if st := ps.StatisticsAware(); st {
if st, isOk := ps.(StatisticsCollector); isOk {
if statInterval := c.sanitizeStatisticsConfig(); statInterval > 0 {
logger.Info().
Str("name", cfg.Name).
Dur("statInterval", statInterval).
Msg("scheduling statistics collection")
ps.ScheduleStatisticsCollection(statInterval)
st.ScheduleStatisticsCollection(statInterval)
} else {
logger.Info().Str("name", cfg.Name).Msg("statistics collection disabled because of zero reporting interval")
}
+43 -43
View File
@@ -3,6 +3,7 @@
package test
import (
"context"
"math/rand"
"net"
"net/netip"
@@ -100,9 +101,8 @@ func (bh *benchHolder) runBenchmark(b *testing.B, parallel bool, sf benchSetupFu
}
b.StopTimer()
errChan := ps.Stop()
for err := range errChan {
b.Fatal(err)
if err := ps.Close(); err != nil {
b.Fatal()
}
}
@@ -125,7 +125,7 @@ func (bh *benchHolder) Nop(b *testing.B) {
// Put can run in parallel.
func (bh *benchHolder) Put(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
return ps.PutSeeder(bd.infoHashes[0], bd.peers[0])
return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0])
})
}
@@ -135,7 +135,7 @@ func (bh *benchHolder) Put(b *testing.B) {
// Put1k can run in parallel.
func (bh *benchHolder) Put1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
return ps.PutSeeder(bd.infoHashes[0], bd.peers[i%peersCount])
return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
})
}
@@ -145,7 +145,7 @@ func (bh *benchHolder) Put1k(b *testing.B) {
// Put1kInfoHash can run in parallel.
func (bh *benchHolder) Put1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
return ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[0])
return ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
})
}
@@ -155,7 +155,7 @@ func (bh *benchHolder) Put1kInfoHash(b *testing.B) {
// Put1kInfoHash1k can run in parallel.
func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
return err
})
}
@@ -166,11 +166,11 @@ func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) {
// PutDelete can not run in parallel.
func (bh *benchHolder) PutDelete(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutSeeder(bd.infoHashes[0], bd.peers[0])
err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0])
})
}
@@ -180,11 +180,11 @@ func (bh *benchHolder) PutDelete(b *testing.B) {
// PutDelete1k can not run in parallel.
func (bh *benchHolder) PutDelete1k(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutSeeder(bd.infoHashes[0], bd.peers[i%peersCount])
err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
})
}
@@ -194,11 +194,11 @@ func (bh *benchHolder) PutDelete1k(b *testing.B) {
// PutDelete1kInfoHash can not run in parallel.
func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[0])
err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
})
}
@@ -208,11 +208,11 @@ func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) {
// PutDelete1kInfoHash1k can not run in parallel.
func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
if err != nil {
return err
}
err = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
return err
})
}
@@ -223,7 +223,7 @@ func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) {
// DeleteNonexist can run in parallel.
func (bh *benchHolder) DeleteNonexist(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0])
_ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0])
return nil
})
}
@@ -234,7 +234,7 @@ func (bh *benchHolder) DeleteNonexist(b *testing.B) {
// DeleteNonexist can run in parallel.
func (bh *benchHolder) DeleteNonexist1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount])
_ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
return nil
})
}
@@ -245,7 +245,7 @@ func (bh *benchHolder) DeleteNonexist1k(b *testing.B) {
// DeleteNonexist1kInfoHash can run in parallel.
func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0])
_ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
return nil
})
}
@@ -256,7 +256,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) {
// DeleteNonexist1kInfoHash1k can run in parallel.
func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
_ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
return nil
})
}
@@ -267,7 +267,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) {
// GradNonexist can run in parallel.
func (bh *benchHolder) GradNonexist(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[0])
_ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0])
return nil
})
}
@@ -278,7 +278,7 @@ func (bh *benchHolder) GradNonexist(b *testing.B) {
// GradNonexist1k can run in parallel.
func (bh *benchHolder) GradNonexist1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[i%peersCount])
_ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
return nil
})
}
@@ -289,7 +289,7 @@ func (bh *benchHolder) GradNonexist1k(b *testing.B) {
// GradNonexist1kInfoHash can run in parallel.
func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[0])
_ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
return nil
})
}
@@ -301,7 +301,7 @@ func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) {
// GradNonexist1kInfoHash1k can run in parallel.
func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) {
bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
_ = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
_ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
return nil
})
}
@@ -313,15 +313,15 @@ func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) {
// PutGradDelete can not run in parallel.
func (bh *benchHolder) PutGradDelete(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutLeecher(bd.infoHashes[0], bd.peers[0])
err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0])
if err != nil {
return err
}
err = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[0])
err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0])
})
}
@@ -330,15 +330,15 @@ func (bh *benchHolder) PutGradDelete(b *testing.B) {
// PutGradDelete1k can not run in parallel.
func (bh *benchHolder) PutGradDelete1k(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutLeecher(bd.infoHashes[0], bd.peers[i%peersCount])
err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
if err != nil {
return err
}
err = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[i%peersCount])
err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount])
})
}
@@ -348,15 +348,15 @@ func (bh *benchHolder) PutGradDelete1k(b *testing.B) {
// PutGradDelete1kInfoHash can not run in parallel.
func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutLeecher(bd.infoHashes[i%ihCount], bd.peers[0])
err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
if err != nil {
return err
}
err = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[0])
err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
if err != nil {
return err
}
return ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0])
return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0])
})
}
@@ -366,15 +366,15 @@ func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) {
// PutGradDelete1kInfoHash can not run in parallel.
func (bh *benchHolder) PutGradDelete1kInfoHash1k(b *testing.B) {
bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error {
err := ps.PutLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
if err != nil {
return err
}
err = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
if err != nil {
return err
}
err = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount])
return err
})
}
@@ -385,9 +385,9 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error {
for i, peer := range bd.peers {
var err error
if i < l/2 {
err = ps.PutLeecher(ih, peer)
err = ps.PutLeecher(context.TODO(), ih, peer)
} else {
err = ps.PutSeeder(ih, peer)
err = ps.PutSeeder(context.TODO(), ih, peer)
}
if err != nil {
return err
@@ -404,7 +404,7 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error {
// AnnounceLeecher can run in parallel.
func (bh *benchHolder) AnnounceLeecher(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
_, err := ps.AnnouncePeers(bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6())
_, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6())
return err
})
}
@@ -415,7 +415,7 @@ func (bh *benchHolder) AnnounceLeecher(b *testing.B) {
// AnnounceLeecher1kInfoHash can run in parallel.
func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
_, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6())
_, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6())
return err
})
}
@@ -426,7 +426,7 @@ func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) {
// AnnounceSeeder can run in parallel.
func (bh *benchHolder) AnnounceSeeder(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
_, err := ps.AnnouncePeers(bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6())
_, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6())
return err
})
}
@@ -437,7 +437,7 @@ func (bh *benchHolder) AnnounceSeeder(b *testing.B) {
// AnnounceSeeder1kInfoHash can run in parallel.
func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
_, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6())
_, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6())
return err
})
}
@@ -448,7 +448,7 @@ func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) {
// ScrapeSwarm can run in parallel.
func (bh *benchHolder) ScrapeSwarm(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
ps.ScrapeSwarm(bd.infoHashes[0])
ps.ScrapeSwarm(context.TODO(), bd.infoHashes[0])
return nil
})
}
@@ -458,7 +458,7 @@ func (bh *benchHolder) ScrapeSwarm(b *testing.B) {
// ScrapeSwarm1kInfoHash can run in parallel.
func (bh *benchHolder) ScrapeSwarm1kInfoHash(b *testing.B) {
bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error {
ps.ScrapeSwarm(bd.infoHashes[i%ihCount])
ps.ScrapeSwarm(context.TODO(), bd.infoHashes[i%ihCount])
return nil
})
}
+40 -39
View File
@@ -3,6 +3,7 @@
package test
import (
"context"
"errors"
"testing"
@@ -40,7 +41,7 @@ type hashPeer struct {
func (th *testHolder) DeleteSeeder(t *testing.T) {
for _, c := range testData {
err := th.st.DeleteSeeder(c.ih, c.peer)
err := th.st.DeleteSeeder(context.TODO(), c.ih, c.peer)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -54,14 +55,14 @@ func (th *testHolder) PutLeecher(t *testing.T) {
if c.peer.Addr().Is6() {
peer = v6Peer
}
err := th.st.PutLeecher(c.ih, peer)
err := th.st.PutLeecher(context.TODO(), c.ih, peer)
require.Nil(t, err)
}
}
func (th *testHolder) DeleteLeecher(t *testing.T) {
for _, c := range testData {
err := th.st.DeleteLeecher(c.ih, c.peer)
err := th.st.DeleteLeecher(context.TODO(), c.ih, c.peer)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -71,7 +72,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) {
func (th *testHolder) AnnouncePeers(t *testing.T) {
for _, c := range testData {
_, err := th.st.AnnouncePeers(c.ih, false, 50, c.peer.Addr().Is6())
_, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, c.peer.Addr().Is6())
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -81,7 +82,7 @@ func (th *testHolder) AnnouncePeers(t *testing.T) {
func (th *testHolder) ScrapeSwarm(t *testing.T) {
for _, c := range testData {
l, s, n := th.st.ScrapeSwarm(c.ih)
l, s, n := th.st.ScrapeSwarm(context.TODO(), c.ih)
require.Equal(t, uint32(0), s)
require.Equal(t, uint32(0), l)
require.Equal(t, uint32(0), n)
@@ -91,26 +92,26 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) {
func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) {
for _, c := range testData {
isV6 := c.peer.Addr().Is6()
err := th.st.PutLeecher(c.ih, c.peer)
err := th.st.PutLeecher(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
peers, err := th.st.AnnouncePeers(c.ih, true, 50, isV6)
peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6)
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
// non-seeder announce should still return the leecher
peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6)
peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6)
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
l, s, _ := th.st.ScrapeSwarm(c.ih)
l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih)
require.Equal(t, uint32(2), l)
require.Equal(t, uint32(0), s)
err = th.st.DeleteLeecher(c.ih, c.peer)
err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
peers, err = th.st.AnnouncePeers(c.ih, true, 50, isV6)
peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -122,22 +123,22 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) {
func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) {
for _, c := range testData {
isV6 := c.peer.Addr().Is6()
err := th.st.PutSeeder(c.ih, c.peer)
err := th.st.PutSeeder(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
// Should be leecher to see the seeder
peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6)
peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6)
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
l, s, _ := th.st.ScrapeSwarm(c.ih)
l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih)
require.Equal(t, uint32(1), l)
require.Equal(t, uint32(1), s)
err = th.st.DeleteSeeder(c.ih, c.peer)
err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6)
peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -153,44 +154,44 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) {
if isV6 {
peer = v6Peer
}
err := th.st.PutLeecher(c.ih, c.peer)
err := th.st.PutLeecher(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
err = th.st.GraduateLeecher(c.ih, c.peer)
err = th.st.GraduateLeecher(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
// Has to be leecher to see the graduated seeder
peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6)
peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6)
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
// Deleting the Peer as a Leecher should have no effect
err = th.st.DeleteLeecher(c.ih, c.peer)
err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
require.Nil(t, err)
// Verify it's still there
peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6)
peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6)
require.Nil(t, err)
require.True(t, containsPeer(peers, c.peer))
// Clean up
err = th.st.DeleteLeecher(c.ih, peer)
err = th.st.DeleteLeecher(context.TODO(), c.ih, peer)
require.Nil(t, err)
// Test ErrDNE for missing leecher
err = th.st.DeleteLeecher(c.ih, peer)
err = th.st.DeleteLeecher(context.TODO(), c.ih, peer)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
require.Nil(t, err)
err = th.st.DeleteSeeder(c.ih, c.peer)
err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer)
require.Nil(t, err)
err = th.st.DeleteSeeder(c.ih, c.peer)
err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer)
if errors.Is(err, storage.ErrResourceDoesNotExist) {
err = nil
}
@@ -200,35 +201,35 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) {
func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) {
for _, c := range testData {
err := th.st.Put(kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())})
err := th.st.Put(context.TODO(), kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())})
require.Nil(t, err)
// check if exist in ctx we put
contains, err := th.st.Contains(kvStoreCtx, c.peer.String())
contains, err := th.st.Contains(context.TODO(), kvStoreCtx, c.peer.String())
require.Nil(t, err)
require.True(t, contains)
// check if not exist in another ctx
contains, err = th.st.Contains("", c.peer.String())
contains, err = th.st.Contains(context.TODO(), "", c.peer.String())
require.Nil(t, err)
require.False(t, contains)
// check value and type in ctx we put
out, err := th.st.Load(kvStoreCtx, c.peer.String())
out, err := th.st.Load(context.TODO(), kvStoreCtx, c.peer.String())
require.Nil(t, err)
ih, err := bittorrent.NewInfoHash(out)
require.Nil(t, err)
require.Equal(t, c.ih, ih)
// check value is nil in another ctx
dummy, err := th.st.Load("", c.peer.String())
dummy, err := th.st.Load(context.TODO(), "", c.peer.String())
require.Nil(t, err)
require.Nil(t, dummy)
err = th.st.Delete(kvStoreCtx, c.peer.String())
err = th.st.Delete(context.TODO(), kvStoreCtx, c.peer.String())
require.Nil(t, err)
contains, err = th.st.Contains("", c.peer.String())
contains, err = th.st.Contains(context.TODO(), "", c.peer.String())
require.Nil(t, err)
require.False(t, contains)
}
@@ -245,29 +246,29 @@ func (th *testHolder) CustomBulkPutContainsLoadDelete(t *testing.T) {
Value: []byte(c.ih.RawString()),
})
}
err := th.st.Put(kvStoreCtx, pairs...)
err := th.st.Put(context.TODO(), kvStoreCtx, pairs...)
require.Nil(t, err)
// check if exist in ctx we put
for _, k := range keys {
contains, err := th.st.Contains(kvStoreCtx, k)
contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k)
require.Nil(t, err)
require.True(t, contains)
}
// check value and type in ctx we put
for _, p := range pairs {
out, _ := th.st.Load(kvStoreCtx, p.Key)
out, _ := th.st.Load(context.TODO(), kvStoreCtx, p.Key)
ih, err := bittorrent.NewInfoHash(out)
require.Nil(t, err)
require.Equal(t, p.Value, []byte(ih.RawString()))
}
err = th.st.Delete(kvStoreCtx, keys...)
err = th.st.Delete(context.TODO(), kvStoreCtx, keys...)
require.Nil(t, err)
for _, k := range keys {
contains, err := th.st.Contains(kvStoreCtx, k)
contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k)
require.Nil(t, err)
require.False(t, contains)
}
@@ -304,8 +305,8 @@ func RunTests(t *testing.T, p storage.PeerStorage) {
t.Run("CustomPutContainsLoadDelete", th.CustomPutContainsLoadDelete)
t.Run("CustomBulkPutContainsLoadDelete", th.CustomBulkPutContainsLoadDelete)
e := th.st.Stop()
require.Nil(t, <-e)
e := th.st.Close()
require.Nil(t, e)
}
func containsPeer(peers []bittorrent.Peer, p bittorrent.Peer) bool {