From d6de38bdbd4447085ae8806bd6e5405531812090 Mon Sep 17 00:00:00 2001 From: "Lawrence, Rendall" Date: Mon, 31 Oct 2022 18:57:57 +0300 Subject: [PATCH 1/3] (wip) pass context from frontend request down to store --- frontend/http/frontend.go | 37 ++++-- middleware/hooks.go | 32 +++--- .../torrentapproval/container/container.go | 3 +- .../container/directory/directory.go | 28 ++--- .../torrentapproval/container/list/list.go | 9 +- middleware/torrentapproval/torrentapproval.go | 2 +- storage/keydb/storage.go | 50 ++++---- storage/memory/storage.go | 25 ++-- storage/pg/storage.go | 78 ++++++------- storage/redis/storage.go | 107 +++++++++--------- storage/storage.go | 27 ++--- storage/test/storage_bench.go | 80 ++++++------- storage/test/storage_test_base.go | 74 ++++++------ 13 files changed, 285 insertions(+), 267 deletions(-) diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index be35a39..fe93992 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -169,6 +169,17 @@ func injectRouteParamsToContext(ctx context.Context, ps httprouter.Params) conte return context.WithValue(ctx, bittorrent.RouteParamsKey, rp) } +func remapRouteParamsToBgContext(inCtx context.Context) context.Context { + rp, isOk := inCtx.Value(bittorrent.RouteParamsKey).(bittorrent.RouteParams) + if !isOk { + rp = bittorrent.RouteParams{} + } else { + logger.Warn().Msg("unable to fetch route parameters, probably jammed context") + } + // FIXME: cancelable context + return context.WithValue(context.TODO(), bittorrent.RouteParamsKey, rp) +} + // announceRoute parses and responds to an Announce. func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { var err error @@ -189,7 +200,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro } addr = req.GetFirst() - ctx := injectRouteParamsToContext(context.Background(), ps) + ctx := injectRouteParamsToContext(r.Context(), ps) ctx, resp, err := f.logic.HandleAnnounce(ctx, req) if err != nil { WriteError(w, err) @@ -203,6 +214,9 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro return } + // next actions are background and should not be canceled after http writer closed + ctx = remapRouteParamsToBgContext(ctx) + go f.logic.AfterAnnounce(ctx, req, resp) } @@ -225,7 +239,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout } addr = req.GetFirst() - ctx := injectRouteParamsToContext(context.Background(), ps) + ctx := injectRouteParamsToContext(r.Context(), ps) ctx, resp, err := f.logic.HandleScrape(ctx, req) if err != nil { WriteError(w, err) @@ -239,18 +253,27 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout return } + // next actions are background and should not be canceled after http writer closed + ctx = remapRouteParamsToBgContext(ctx) + go f.logic.AfterScrape(ctx, req, resp) } func (f *httpFE) ping(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var err error + status := http.StatusOK + ctx := r.Context() if r.Method == http.MethodGet { - err = f.logic.Ping(context.TODO()) + err = f.logic.Ping(ctx) } - if err == nil { - w.WriteHeader(http.StatusOK) - } else { + + if err != nil { logger.Error().Err(err).Msg("ping completed with error") - w.WriteHeader(http.StatusServiceUnavailable) + status = http.StatusServiceUnavailable + } + if ctxErr := ctx.Err(); ctxErr == nil { + w.WriteHeader(status) + } else { + logger.Info().Err(ctxErr).Str("ip", r.RemoteAddr).Msg("ping request cancelled") } } diff --git a/middleware/hooks.go b/middleware/hooks.go index 444969b..71dd9b5 100644 --- a/middleware/hooks.go +++ b/middleware/hooks.go @@ -46,17 +46,17 @@ func (h *swarmInteractionHook) HandleAnnounce(ctx context.Context, req *bittorre return } - var storeFn func(bittorrent.InfoHash, bittorrent.Peer) error + var storeFn func(context.Context, bittorrent.InfoHash, bittorrent.Peer) error switch { case req.Event == bittorrent.Stopped: - storeFn = func(hash bittorrent.InfoHash, peer bittorrent.Peer) error { - err = h.store.DeleteSeeder(hash, peer) + storeFn = func(ctx context.Context, hash bittorrent.InfoHash, peer bittorrent.Peer) error { + err = h.store.DeleteSeeder(ctx, hash, peer) if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) { return err } - err = h.store.DeleteLeecher(hash, peer) + err = h.store.DeleteLeecher(ctx, hash, peer) if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) { return err } @@ -74,8 +74,8 @@ func (h *swarmInteractionHook) HandleAnnounce(ctx context.Context, req *bittorre storeFn = h.store.PutLeecher } for _, p := range req.Peers() { - if err = storeFn(req.InfoHash, p); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len { - err = storeFn(req.InfoHash.TruncateV1(), p) + if err = storeFn(ctx, req.InfoHash, p); err == nil && len(req.InfoHash) == bittorrent.InfoHashV2Len { + err = storeFn(ctx, req.InfoHash.TruncateV1(), p) } if err != nil { break @@ -102,10 +102,10 @@ type responseHook struct { store storage.PeerStorage } -func (h *responseHook) scrape(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { - leechers, seeders, snatched = h.store.ScrapeSwarm(ih) +func (h *responseHook) scrape(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { + leechers, seeders, snatched = h.store.ScrapeSwarm(ctx, ih) if len(ih) == bittorrent.InfoHashV2Len { - l, s, n := h.store.ScrapeSwarm(ih.TruncateV1()) + l, s, n := h.store.ScrapeSwarm(ctx, ih.TruncateV1()) leechers, seeders, snatched = leechers+l, seeders+s, snatched+n } return @@ -117,9 +117,9 @@ func (h *responseHook) HandleAnnounce(ctx context.Context, req *bittorrent.Annou } // Add the Scrape data to the response. - resp.Incomplete, resp.Complete, _ = h.scrape(req.InfoHash) + resp.Incomplete, resp.Complete, _ = h.scrape(ctx, req.InfoHash) - err = h.appendPeers(req, resp) + err = h.appendPeers(ctx, req, resp) return ctx, err } @@ -128,7 +128,7 @@ type fetchArgs struct { v6 bool } -func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (err error) { +func (h *responseHook) appendPeers(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (err error) { seeding := req.Left == 0 max := int(req.NumWant) peers := make([]bittorrent.Peer, 0, len(resp.IPv4Peers)+len(resp.IPv6Peers)) @@ -159,7 +159,7 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor break } var storePeers []bittorrent.Peer - storePeers, err = h.store.AnnouncePeers(a.ih, seeding, max, a.v6) + storePeers, err = h.store.AnnouncePeers(ctx, a.ih, seeding, max, a.v6) if err != nil && !errors.Is(err, storage.ErrResourceDoesNotExist) { return err } @@ -209,13 +209,13 @@ func (h *responseHook) HandleScrape(ctx context.Context, req *bittorrent.ScrapeR for _, infoHash := range req.InfoHashes { scr := bittorrent.Scrape{InfoHash: infoHash} - scr.Incomplete, scr.Complete, scr.Snatches = h.scrape(infoHash) + scr.Incomplete, scr.Complete, scr.Snatches = h.scrape(ctx, infoHash) resp.Files = append(resp.Files, scr) } return ctx, nil } -func (h *responseHook) Ping(_ context.Context) error { - return h.store.Ping() +func (h *responseHook) Ping(ctx context.Context) error { + return h.store.Ping(ctx) } diff --git a/middleware/torrentapproval/container/container.go b/middleware/torrentapproval/container/container.go index 155d71c..ed58db5 100644 --- a/middleware/torrentapproval/container/container.go +++ b/middleware/torrentapproval/container/container.go @@ -3,6 +3,7 @@ package container import ( + "context" "errors" "sync" @@ -41,7 +42,7 @@ func Register(n string, c Builder) { // Container holds InfoHash and checks if value approved or not type Container interface { - Approved(bittorrent.InfoHash) bool + Approved(context.Context, bittorrent.InfoHash) bool } // GetContainer creates Container by its name and provided confBytes diff --git a/middleware/torrentapproval/container/directory/directory.go b/middleware/torrentapproval/container/directory/directory.go index d002fa4..9358cbc 100644 --- a/middleware/torrentapproval/container/directory/directory.go +++ b/middleware/torrentapproval/container/directory/directory.go @@ -5,6 +5,7 @@ package directory import ( + "context" "fmt" "github.com/anacrolix/torrent/metainfo" @@ -85,28 +86,23 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er name = list.DUMMY } bName := []byte(name) - logger.Err(d.Storage.Put(d.StorageCtx, - storage.Entry{ - Key: event.InfoHash.AsString(), - Value: bName, - }, storage.Entry{ - Key: v2hash.RawString(), - Value: bName, - }, storage.Entry{ - Key: v2hash.TruncateV1().RawString(), - Value: bName, - })). + logger.Err(d.Storage.Put(context.Background(), d.StorageCtx, storage.Entry{ + Key: event.InfoHash.AsString(), + Value: bName, + }, storage.Entry{ + Key: v2hash.RawString(), + Value: bName, + }, storage.Entry{ + Key: v2hash.TruncateV1().RawString(), + Value: bName, + })). Str("action", "add"). Str("file", event.TorrentFilePath). Stringer("infoHash", event.InfoHash). Stringer("infoHashV2", v2hash). Msg("approval torrent watcher event") case dirwatch.Removed: - logger.Err(d.Storage.Delete(c.StorageCtx, - event.InfoHash.AsString(), - v2hash.RawString(), - v2hash.TruncateV1().RawString(), - )). + logger.Err(d.Storage.Delete(context.Background(), c.StorageCtx, event.InfoHash.AsString(), v2hash.RawString(), v2hash.TruncateV1().RawString())). Str("action", "delete"). Str("file", event.TorrentFilePath). Stringer("infoHash", event.InfoHash). diff --git a/middleware/torrentapproval/container/list/list.go b/middleware/torrentapproval/container/list/list.go index 0440411..cd61d25 100644 --- a/middleware/torrentapproval/container/list/list.go +++ b/middleware/torrentapproval/container/list/list.go @@ -3,6 +3,7 @@ package list import ( + "context" "fmt" "github.com/sot-tech/mochi/bittorrent" @@ -64,7 +65,7 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er init = append(init, storage.Entry{Key: ih.TruncateV1().RawString(), Value: []byte(DUMMY)}) } } - if err := l.Storage.Put(l.StorageCtx, init...); err != nil { + if err := l.Storage.Put(context.Background(), l.StorageCtx, init...); err != nil { return nil, fmt.Errorf("unable to put initial data: %w", err) } } @@ -84,11 +85,11 @@ type List struct { // Approved checks if specified hash is approved or not. // If List.Invert set to true and hash found in storage, function will return false, // that means that hash is blacklisted. -func (l *List) Approved(hash bittorrent.InfoHash) (contains bool) { +func (l *List) Approved(ctx context.Context, hash bittorrent.InfoHash) (contains bool) { var err error - if contains, err = l.Storage.Contains(l.StorageCtx, hash.RawString()); err == nil { + if contains, err = l.Storage.Contains(ctx, l.StorageCtx, hash.RawString()); err == nil { if len(hash) == bittorrent.InfoHashV2Len { - if containsV2, errV2 := l.Storage.Contains(l.StorageCtx, hash.TruncateV1().RawString()); err == nil { + if containsV2, errV2 := l.Storage.Contains(ctx, l.StorageCtx, hash.TruncateV1().RawString()); err == nil { contains = contains || containsV2 } else { err = errV2 diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index fab9f8a..93ca348 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -74,7 +74,7 @@ type hook struct { func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, _ *bittorrent.AnnounceResponse) (context.Context, error) { var err error - if !h.hashContainer.Approved(req.InfoHash) { + if !h.hashContainer.Approved(ctx, req.InfoHash) { err = ErrTorrentUnapproved } diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index dbb9acb..7bdca86 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -93,27 +93,23 @@ type store struct { peerTTL uint } -func (s *store) setPeerTTL(infoHashKey, peerID string) error { - return s.Process(context.TODO(), redis.NewCmd(context.TODO(), expireMemberCmd, infoHashKey, peerID, s.peerTTL)) -} - -func (s *store) addPeer(infoHashKey, peerID string) (err error) { +func (s *store) addPeer(ctx context.Context, infoHashKey, peerID string) (err error) { logger.Trace(). Str("infoHashKey", infoHashKey). Str("peerID", peerID). Msg("add peer") - if err = s.SAdd(context.TODO(), infoHashKey, peerID).Err(); err == nil { - err = s.setPeerTTL(infoHashKey, peerID) + if err = s.SAdd(ctx, infoHashKey, peerID).Err(); err == nil { + err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, infoHashKey, peerID, s.peerTTL)) } return } -func (s *store) delPeer(infoHashKey, peerID string) error { +func (s *store) delPeer(ctx context.Context, infoHashKey, peerID string) error { logger.Trace(). Str("infoHashKey", infoHashKey). Str("peerID", peerID). Msg("del peer") - deleted, err := s.SRem(context.TODO(), infoHashKey, peerID).Uint64() + deleted, err := s.SRem(ctx, infoHashKey, peerID).Uint64() err = r.AsNil(err) if err == nil && deleted == 0 { err = storage.ErrResourceDoesNotExist @@ -122,23 +118,23 @@ func (s *store) delPeer(infoHashKey, peerID string) error { return err } -func (s *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.addPeer(r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) +func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) } -func (s *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) +func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), peer.RawString()) } -func (s *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.addPeer(r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) +func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.addPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) } -func (s *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) +func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.delPeer(ctx, r.InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), peer.RawString()) } -func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) { +func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) (err error) { logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). @@ -147,21 +143,21 @@ func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) (e ihSeederKey := r.InfoHashKey(infoHash, true, peer.Addr().Is6()) ihLeecherKey := r.InfoHashKey(infoHash, false, peer.Addr().Is6()) var moved bool - if moved, err = s.SMove(context.TODO(), ihLeecherKey, ihSeederKey, peerID).Result(); err == nil { + if moved, err = s.SMove(ctx, ihLeecherKey, ihSeederKey, peerID).Result(); err == nil { if moved { - err = s.setPeerTTL(ihSeederKey, peerID) + err = s.Process(ctx, redis.NewCmd(ctx, expireMemberCmd, ihSeederKey, peerID, s.peerTTL)) } else { - err = s.addPeer(ihSeederKey, peerID) + err = s.addPeer(ctx, ihSeederKey, peerID) } if err == nil { - err = s.HIncrBy(context.TODO(), r.CountDownloadsKey, infoHash, 1).Err() + err = s.HIncrBy(ctx, r.CountDownloadsKey, infoHash, 1).Err() } } return err } // AnnouncePeers is the same function as redis.AnnouncePeers -func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { +func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { logger.Trace(). Stringer("infoHash", ih). Bool("forSeeder", forSeeder). @@ -169,17 +165,17 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in Bool("v6", v6). Msg("announce peers") - return s.GetPeers(ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { - return s.SRandMemberN(context.TODO(), infoHashKey, int64(maxCount)) + return s.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd { + return s.SRandMemberN(ctx, infoHashKey, int64(maxCount)) }) } // ScrapeSwarm is the same function as redis.ScrapeSwarm except `SCard` call instead of `HLen` -func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (uint32, uint32, uint32) { +func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32, uint32, uint32) { logger.Trace(). Stringer("infoHash", ih). Msg("scrape swarm") - return s.ScrapeIH(ih, s.SCard) + return s.ScrapeIH(ctx, ih, s.SCard) } func (*store) GCAware() bool { diff --git a/storage/memory/storage.go b/storage/memory/storage.go index b586e5f..a88c6ad 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -3,6 +3,7 @@ package memory import ( + "context" "encoding/binary" "math" "runtime" @@ -174,7 +175,7 @@ func (ps *peerStore) shardIndex(infoHash bittorrent.InfoHash, v6 bool) uint32 { return idx } -func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) PutSeeder(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -207,7 +208,7 @@ func (ps *peerStore) PutSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error return nil } -func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) DeleteSeeder(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -240,7 +241,7 @@ func (ps *peerStore) DeleteSeeder(ih bittorrent.InfoHash, p bittorrent.Peer) err return nil } -func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) PutLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -273,7 +274,7 @@ func (ps *peerStore) PutLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error return nil } -func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) DeleteLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -306,7 +307,7 @@ func (ps *peerStore) DeleteLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) er return nil } -func (ps *peerStore) GraduateLeecher(ih bittorrent.InfoHash, p bittorrent.Peer) error { +func (ps *peerStore) GraduateLeecher(_ context.Context, ih bittorrent.InfoHash, p bittorrent.Peer) error { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -372,7 +373,7 @@ func (ps *peerStore) getPeers(shard *peerShard, ih bittorrent.InfoHash, maxCount return } -func (ps *peerStore) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) { +func (ps *peerStore) AnnouncePeers(_ context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -401,7 +402,7 @@ func (ps *peerStore) countPeers(ih bittorrent.InfoHash, v6 bool) (leechers, seed return } -func (ps *peerStore) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { +func (ps *peerStore) ScrapeSwarm(_ context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { select { case <-ps.closed: panic("attempted to interact with stopped memory store") @@ -427,7 +428,7 @@ type dataStore struct { sync.Map } -func (ds *dataStore) Put(ctx string, values ...storage.Entry) error { +func (ds *dataStore) Put(_ context.Context, ctx string, values ...storage.Entry) error { if len(values) > 0 { c, _ := ds.LoadOrStore(ctx, new(sync.Map)) m := c.(*sync.Map) @@ -438,7 +439,7 @@ func (ds *dataStore) Put(ctx string, values ...storage.Entry) error { return nil } -func (ds *dataStore) Contains(ctx string, key string) (bool, error) { +func (ds *dataStore) Contains(_ context.Context, ctx string, key string) (bool, error) { var exist bool if m, found := ds.Map.Load(ctx); found { _, exist = m.(*sync.Map).Load(key) @@ -446,7 +447,7 @@ func (ds *dataStore) Contains(ctx string, key string) (bool, error) { return exist, nil } -func (ds *dataStore) Load(ctx string, key string) (out []byte, _ error) { +func (ds *dataStore) Load(_ context.Context, ctx string, key string) (out []byte, _ error) { if m, found := ds.Map.Load(ctx); found { if v, _ := m.(*sync.Map).Load(key); v != nil { out = v.([]byte) @@ -455,7 +456,7 @@ func (ds *dataStore) Load(ctx string, key string) (out []byte, _ error) { return } -func (ds *dataStore) Delete(ctx string, keys ...string) error { +func (ds *dataStore) Delete(_ context.Context, ctx string, keys ...string) error { if len(keys) > 0 { if m, found := ds.Map.Load(ctx); found { m := m.(*sync.Map) @@ -537,7 +538,7 @@ func (ps *peerStore) gc(cutoff time.Time) { } } -func (*peerStore) Ping() error { +func (*peerStore) Ping(context.Context) error { return nil } diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 7e8dd1b..0f238d0 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -222,7 +222,7 @@ type store struct { func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) { var tx pgx.Tx if tx, err = s.Begin(ctx); err == nil { - if err = tx.SendBatch(context.TODO(), batch).Close(); err == nil { + if err = tx.SendBatch(ctx, batch).Close(); err == nil { err = tx.Commit(ctx) } else { if txErr := tx.Rollback(ctx); txErr != nil { @@ -233,25 +233,25 @@ func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) { return } -func (s *store) Put(ctx string, values ...storage.Entry) (err error) { +func (s *store) Put(ctx context.Context, storeCtx string, values ...storage.Entry) (err error) { switch len(values) { case 0: // ignore case 1: - _, err = s.Exec(context.TODO(), s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(values[0].Key), pValue: values[0].Value}) + _, err = s.Exec(ctx, s.Data.AddQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(values[0].Key), pValue: values[0].Value}) default: var batch pgx.Batch for _, v := range values { batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(v.Key), pValue: v.Value}) } - err = s.txBatch(context.TODO(), &batch) + err = s.txBatch(ctx, &batch) } return } -func (s *store) Contains(ctx string, key string) (contains bool, err error) { +func (s *store) Contains(ctx context.Context, storeCtx string, key string) (contains bool, err error) { var rows pgx.Rows - if rows, err = s.Query(context.TODO(), s.Data.GetQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(key)}); err == nil { + if rows, err = s.Query(ctx, s.Data.GetQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(key)}); err == nil { defer rows.Close() contains = rows.Next() err = rows.Err() @@ -259,20 +259,20 @@ func (s *store) Contains(ctx string, key string) (contains bool, err error) { return } -func (s *store) Load(ctx string, key string) (out []byte, err error) { - if err = s.QueryRow(context.TODO(), s.Data.GetQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(key)}).Scan(&out); errors.Is(err, pgx.ErrNoRows) { +func (s *store) Load(ctx context.Context, storeCtx string, key string) (out []byte, err error) { + if err = s.QueryRow(ctx, s.Data.GetQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(key)}).Scan(&out); errors.Is(err, pgx.ErrNoRows) { err = nil } return } -func (s *store) Delete(ctx string, keys ...string) (err error) { +func (s *store) Delete(ctx context.Context, storeCtx string, keys ...string) (err error) { if len(keys) > 0 { baKeys := make([][]byte, len(keys)) for i, k := range keys { baKeys[i] = []byte(k) } - _, err = s.Exec(context.TODO(), s.Data.DelQuery, pgx.NamedArgs{pCtx: ctx, pKey: baKeys}) + _, err = s.Exec(ctx, s.Data.DelQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: baKeys}) } return } @@ -328,7 +328,7 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) { case <-t.C: if metrics.Enabled() { before := time.Now() - sc, lc := s.countPeers(nil) + sc, lc := s.countPeers(context.Background(), nil) var hc int if err := s.QueryRow(context.Background(), s.InfoHashCountQuery).Scan(&hc); err != nil && !errors.Is(err, pgx.ErrNoRows) { logger.Error().Err(err).Msg("error occurred while get info hash count") @@ -344,7 +344,7 @@ func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) { }() } -func (s *store) putPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { +func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). @@ -361,16 +361,16 @@ func (s *store) putPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo if s.GCAware() { args[pCreated] = timecache.Now() } - _, err = s.Exec(context.TODO(), s.Peer.AddQuery, args) + _, err = s.Exec(ctx, s.Peer.AddQuery, args) return } -func (s *store) delPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { +func (s *store) delPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer, seeder bool) (err error) { logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). Msg("del peer") - _, err = s.Exec(context.TODO(), s.Peer.DelQuery, pgx.NamedArgs{ + _, err = s.Exec(ctx, s.Peer.DelQuery, pgx.NamedArgs{ pInfoHash: []byte(ih), pPeerID: peer.ID[:], pAddress: net.IP(peer.Addr().AsSlice()), @@ -380,23 +380,23 @@ func (s *store) delPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo return } -func (s *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.putPeer(ih, peer, true) +func (s *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.putPeer(ctx, ih, peer, true) } -func (s *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ih, peer, true) +func (s *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.delPeer(ctx, ih, peer, true) } -func (s *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.putPeer(ih, peer, false) +func (s *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.putPeer(ctx, ih, peer, false) } -func (s *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return s.delPeer(ih, peer, false) +func (s *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return s.delPeer(ctx, ih, peer, false) } -func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { +func (s *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). @@ -410,12 +410,12 @@ func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) er pPort: peer.Port(), }) batch.Queue(s.Downloads.IncrementQuery, pgx.NamedArgs{pInfoHash: ihb}) - return s.txBatch(context.TODO(), &batch) + return s.txBatch(ctx, &batch) } -func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) { +func (s *store) getPeers(ctx context.Context, ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) { var rows pgx.Rows - if rows, err = s.Query(context.TODO(), s.Announce.Query, pgx.NamedArgs{ + if rows, err = s.Query(ctx, s.Announce.Query, pgx.NamedArgs{ pInfoHash: []byte(ih), pSeeder: seeders, pV6: isV6, @@ -484,7 +484,7 @@ func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV return } -func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) { +func (s *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) { logger.Trace(). Stringer("infoHash", ih). Bool("forSeeder", forSeeder). @@ -492,11 +492,11 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in Bool("v6", v6). Msg("announce peers") if forSeeder { - peers, err = s.getPeers(ih, false, numWant, v6) + peers, err = s.getPeers(ctx, ih, false, numWant, v6) } else { - if peers, err = s.getPeers(ih, true, numWant, v6); err == nil { + if peers, err = s.getPeers(ctx, ih, true, numWant, v6); err == nil { var addPeers []bittorrent.Peer - addPeers, err = s.getPeers(ih, false, numWant-len(peers), v6) + addPeers, err = s.getPeers(ctx, ih, false, numWant-len(peers), v6) peers = append(peers, addPeers...) } } @@ -513,13 +513,13 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in return } -func (s *store) countPeers(ih []byte) (seeders uint32, leechers uint32) { +func (s *store) countPeers(ctx context.Context, ih []byte) (seeders uint32, leechers uint32) { var rows pgx.Rows var err error if len(ih) == 0 { - rows, err = s.Query(context.TODO(), s.Peer.CountQuery) + rows, err = s.Query(ctx, s.Peer.CountQuery) } else { - rows, err = s.Query(context.TODO(), s.Peer.CountQuery+" "+s.Peer.ByInfoHashClause, pgx.NamedArgs{pInfoHash: ih}) + rows, err = s.Query(ctx, s.Peer.CountQuery+" "+s.Peer.ByInfoHashClause, pgx.NamedArgs{pInfoHash: ih}) } if err == nil { defer rows.Close() @@ -556,14 +556,14 @@ func (s *store) countPeers(ih []byte) (seeders uint32, leechers uint32) { return } -func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { +func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) { logger.Trace(). Stringer("infoHash", ih). Msg("scrape swarm") ihb := []byte(ih) - seeders, leechers = s.countPeers(ihb) + seeders, leechers = s.countPeers(ctx, ihb) if len(s.Downloads.GetQuery) > 0 { - if err := s.QueryRow(context.TODO(), s.Downloads.GetQuery, pgx.NamedArgs{pInfoHash: ihb}).Scan(&snatched); err != nil && !errors.Is(err, pgx.ErrNoRows) { + if err := s.QueryRow(ctx, s.Downloads.GetQuery, pgx.NamedArgs{pInfoHash: ihb}).Scan(&snatched); err != nil && !errors.Is(err, pgx.ErrNoRows) { logger.Error().Stringer("infoHash", ih).Err(err).Msg("error occurred while get info downloads count") } } @@ -571,8 +571,8 @@ func (s *store) ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders ui return } -func (s *store) Ping() error { - _, err := s.Exec(context.TODO(), s.PingQuery) +func (s *store) Ping(ctx context.Context) error { + _, err := s.Exec(ctx, s.PingQuery) return err } diff --git a/storage/redis/storage.go b/storage/redis/storage.go index f7c26bf..d50b4b9 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -309,8 +309,8 @@ func (ps *store) getClock() int64 { return timecache.NowUnixNano() } -func (ps *store) tx(txf func(tx redis.Pipeliner) error) (err error) { - if pipe, txErr := ps.TxPipelined(context.TODO(), txf); txErr == nil { +func (ps *store) tx(ctx context.Context, txf func(tx redis.Pipeliner) error) (err error) { + if pipe, txErr := ps.TxPipelined(ctx, txf); txErr == nil { errs := make([]string, 0) for _, c := range pipe { if err := c.Err(); err != nil { @@ -358,58 +358,58 @@ func InfoHashKey(infoHash string, seeder, v6 bool) (infoHashKey string) { return } -func (ps *store) putPeer(infoHashKey, peerCountKey, peerID string) error { +func (ps *store) putPeer(ctx context.Context, infoHashKey, peerCountKey, peerID string) error { logger.Trace(). Str("infoHashKey", infoHashKey). Str("peerID", peerID). Msg("put peer") - return ps.tx(func(tx redis.Pipeliner) (err error) { - if err = tx.HSet(context.TODO(), infoHashKey, peerID, ps.getClock()).Err(); err != nil { + return ps.tx(ctx, func(tx redis.Pipeliner) (err error) { + if err = tx.HSet(ctx, infoHashKey, peerID, ps.getClock()).Err(); err != nil { return } - if err = tx.Incr(context.TODO(), peerCountKey).Err(); err != nil { + if err = tx.Incr(ctx, peerCountKey).Err(); err != nil { return } - err = tx.SAdd(context.TODO(), IHKey, infoHashKey).Err() + err = tx.SAdd(ctx, IHKey, infoHashKey).Err() return }) } -func (ps *store) delPeer(infoHashKey, peerCountKey, peerID string) error { +func (ps *store) delPeer(ctx context.Context, infoHashKey, peerCountKey, peerID string) error { logger.Trace(). Str("infoHashKey", infoHashKey). Str("peerID", peerID). Msg("del peer") - deleted, err := ps.HDel(context.TODO(), infoHashKey, peerID).Uint64() + deleted, err := ps.HDel(ctx, infoHashKey, peerID).Uint64() err = AsNil(err) if err == nil { if deleted == 0 { err = storage.ErrResourceDoesNotExist } else { - err = ps.Decr(context.TODO(), peerCountKey).Err() + err = ps.Decr(ctx, peerCountKey).Err() } } return err } -func (ps *store) PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.putPeer(InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) +func (ps *store) PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return ps.putPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) } -func (ps *store) DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.delPeer(InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) +func (ps *store) DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return ps.delPeer(ctx, InfoHashKey(ih.RawString(), true, peer.Addr().Is6()), CountSeederKey, peer.RawString()) } -func (ps *store) PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.putPeer(InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) +func (ps *store) PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return ps.putPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) } -func (ps *store) DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { - return ps.delPeer(InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) +func (ps *store) DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { + return ps.delPeer(ctx, InfoHashKey(ih.RawString(), false, peer.Addr().Is6()), CountLeecherKey, peer.RawString()) } -func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error { +func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error { logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). @@ -418,25 +418,25 @@ func (ps *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) e infoHash, peerID, isV6 := ih.RawString(), peer.RawString(), peer.Addr().Is6() ihSeederKey, ihLeecherKey := InfoHashKey(infoHash, true, isV6), InfoHashKey(infoHash, false, isV6) - return ps.tx(func(tx redis.Pipeliner) error { - deleted, err := tx.HDel(context.TODO(), ihLeecherKey, peerID).Uint64() + return ps.tx(ctx, func(tx redis.Pipeliner) error { + deleted, err := tx.HDel(ctx, ihLeecherKey, peerID).Uint64() err = AsNil(err) if err == nil { if deleted > 0 { - err = tx.Decr(context.TODO(), CountLeecherKey).Err() + err = tx.Decr(ctx, CountLeecherKey).Err() } } if err == nil { - err = tx.HSet(context.TODO(), ihSeederKey, peerID, ps.getClock()).Err() + err = tx.HSet(ctx, ihSeederKey, peerID, ps.getClock()).Err() } if err == nil { - err = tx.Incr(context.TODO(), CountSeederKey).Err() + err = tx.Incr(ctx, CountSeederKey).Err() } if err == nil { - err = tx.SAdd(context.TODO(), IHKey, ihSeederKey).Err() + err = tx.SAdd(ctx, IHKey, ihSeederKey).Err() } if err == nil { - err = tx.HIncrBy(context.TODO(), CountDownloadsKey, infoHash, 1).Err() + err = tx.HIncrBy(ctx, CountDownloadsKey, infoHash, 1).Err() } return err }) @@ -457,7 +457,7 @@ func (ps *Connection) parsePeersList(peersResult *redis.StringSliceCmd) (peers [ return } -type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd +type getPeersFn func(string, int) *redis.StringSliceCmd // GetPeers retrieves peers for provided info hash by calling membersFn and // converts result to bittorrent.Peer array. @@ -477,7 +477,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount for _, infoHashKey := range infoHashKeys { var peers []bittorrent.Peer - peers, err = ps.parsePeersList(membersFn(context.TODO(), infoHashKey, maxCount)) + peers, err = ps.parsePeersList(membersFn(infoHashKey, maxCount)) maxCount -= len(peers) out = append(out, peers...) if err != nil || maxCount <= 0 { @@ -497,7 +497,7 @@ func (ps *Connection) GetPeers(ih bittorrent.InfoHash, forSeeder bool, maxCount return } -func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { +func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { logger.Trace(). Stringer("infoHash", ih). Bool("forSeeder", forSeeder). @@ -505,15 +505,15 @@ func (ps *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant i Bool("v6", v6). Msg("announce peers") - return ps.GetPeers(ih, forSeeder, numWant, v6, func(ctx context.Context, infoHashKey string, maxCount int) *redis.StringSliceCmd { + return ps.GetPeers(ih, forSeeder, numWant, v6, func(infoHashKey string, maxCount int) *redis.StringSliceCmd { return ps.HRandField(ctx, infoHashKey, maxCount, false) }) } type getPeerCountFn func(context.Context, string) *redis.IntCmd -func (ps *Connection) countPeers(infoHashKey string, countFn getPeerCountFn) uint32 { - count, err := countFn(context.TODO(), infoHashKey).Result() +func (ps *Connection) countPeers(ctx context.Context, infoHashKey string, countFn getPeerCountFn) uint32 { + count, err := countFn(ctx, infoHashKey).Result() err = AsNil(err) if err != nil { logger.Error().Err(err).Str("infoHashKey", infoHashKey).Msg("key size calculation failure") @@ -522,14 +522,14 @@ func (ps *Connection) countPeers(infoHashKey string, countFn getPeerCountFn) uin } // ScrapeIH calls provided countFn and returns seeders, leechers and downloads count for specified info hash -func (ps *Connection) ScrapeIH(ih bittorrent.InfoHash, countFn getPeerCountFn) (leechersCount, seedersCount, downloadsCount uint32) { +func (ps *Connection) ScrapeIH(ctx context.Context, ih bittorrent.InfoHash, countFn getPeerCountFn) (leechersCount, seedersCount, downloadsCount uint32) { infoHash := ih.RawString() - leechersCount = ps.countPeers(InfoHashKey(infoHash, false, false), countFn) + - ps.countPeers(InfoHashKey(infoHash, false, true), countFn) - seedersCount = ps.countPeers(InfoHashKey(infoHash, true, false), countFn) + - ps.countPeers(InfoHashKey(infoHash, true, true), countFn) - d, err := ps.HGet(context.TODO(), CountDownloadsKey, infoHash).Uint64() + leechersCount = ps.countPeers(ctx, InfoHashKey(infoHash, false, false), countFn) + + ps.countPeers(ctx, InfoHashKey(infoHash, false, true), countFn) + seedersCount = ps.countPeers(ctx, InfoHashKey(infoHash, true, false), countFn) + + ps.countPeers(ctx, InfoHashKey(infoHash, true, true), countFn) + d, err := ps.HGet(ctx, CountDownloadsKey, infoHash).Uint64() if err = AsNil(err); err != nil { logger.Error().Err(err).Str("infoHash", infoHash).Msg("downloads count calculation failure") } @@ -538,31 +538,31 @@ func (ps *Connection) ScrapeIH(ih bittorrent.InfoHash, countFn getPeerCountFn) ( return } -func (ps *store) ScrapeSwarm(ih bittorrent.InfoHash) (uint32, uint32, uint32) { +func (ps *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32, uint32, uint32) { logger.Trace(). Stringer("infoHash", ih). Msg("scrape swarm") - return ps.ScrapeIH(ih, ps.HLen) + return ps.ScrapeIH(ctx, ih, ps.HLen) } const argNumErrorMsg = "ERR wrong number of arguments" // Put - storage.DataStorage implementation -func (ps *Connection) Put(ctx string, values ...storage.Entry) (err error) { +func (ps *Connection) Put(ctx context.Context, storeCtx string, values ...storage.Entry) (err error) { if l := len(values); l > 0 { if l == 1 { - err = ps.HSet(context.TODO(), PrefixKey+ctx, values[0].Key, values[0].Value).Err() + err = ps.HSet(ctx, PrefixKey+storeCtx, values[0].Key, values[0].Value).Err() } else { args := make([]any, 0, l*2) for _, p := range values { args = append(args, p.Key, p.Value) } - err = ps.HSet(context.TODO(), PrefixKey+ctx, args...).Err() + err = ps.HSet(ctx, PrefixKey+storeCtx, args...).Err() if err != nil { if strings.Contains(err.Error(), argNumErrorMsg) { logger.Warn().Msg("This Redis version/implementation does not support variadic arguments for HSET") for _, p := range values { - if err = ps.HSet(context.TODO(), PrefixKey+ctx, p.Key, p.Value).Err(); err != nil { + if err = ps.HSet(ctx, PrefixKey+storeCtx, p.Key, p.Value).Err(); err != nil { break } } @@ -574,14 +574,14 @@ func (ps *Connection) Put(ctx string, values ...storage.Entry) (err error) { } // Contains - storage.DataStorage implementation -func (ps *Connection) Contains(ctx string, key string) (bool, error) { - exist, err := ps.HExists(context.TODO(), PrefixKey+ctx, key).Result() +func (ps *Connection) Contains(ctx context.Context, storeCtx string, key string) (bool, error) { + exist, err := ps.HExists(ctx, PrefixKey+storeCtx, key).Result() return exist, AsNil(err) } // Load - storage.DataStorage implementation -func (ps *Connection) Load(ctx string, key string) (v []byte, err error) { - v, err = ps.HGet(context.TODO(), PrefixKey+ctx, key).Bytes() +func (ps *Connection) Load(ctx context.Context, storeCtx string, key string) (v []byte, err error) { + v, err = ps.HGet(ctx, PrefixKey+storeCtx, key).Bytes() if err != nil && errors.Is(err, redis.Nil) { v, err = nil, nil } @@ -589,14 +589,14 @@ func (ps *Connection) Load(ctx string, key string) (v []byte, err error) { } // Delete - storage.DataStorage implementation -func (ps *Connection) Delete(ctx string, keys ...string) (err error) { +func (ps *Connection) Delete(ctx context.Context, storeCtx string, keys ...string) (err error) { if len(keys) > 0 { - err = AsNil(ps.HDel(context.TODO(), PrefixKey+ctx, keys...).Err()) + err = AsNil(ps.HDel(ctx, PrefixKey+storeCtx, keys...).Err()) if err != nil { if strings.Contains(err.Error(), argNumErrorMsg) { logger.Warn().Msg("This Redis version/implementation does not support variadic arguments for HDEL") for _, k := range keys { - if err = AsNil(ps.HDel(context.TODO(), PrefixKey+ctx, k).Err()); err != nil { + if err = AsNil(ps.HDel(ctx, PrefixKey+storeCtx, k).Err()); err != nil { break } } @@ -620,8 +620,8 @@ func (*store) StatisticsAware() bool { } // Ping sends `PING` request to Redis server -func (ps *Connection) Ping() error { - return ps.UniversalClient.Ping(context.TODO()).Err() +func (ps *Connection) Ping(ctx context.Context) error { + return ps.UniversalClient.Ping(ctx).Err() } // GC deletes all Peers from the PeerStorage which are older than the @@ -747,7 +747,6 @@ func (ps *store) gc(cutoff time.Time) { if err == nil && infoHashCount == 0 { // Empty hashes are not shown among existing keys, // in other words, it's removed automatically after `HDEL` the last field. - // _, err := ps.Del(context.TODO(), infoHashKey) err = AsNil(ps.SRem(context.Background(), IHKey, infoHashKey).Err()) } return err diff --git a/storage/storage.go b/storage/storage.go index 01628cd..bd61386 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,6 +3,7 @@ package storage import ( + "context" "fmt" "sync" "time" @@ -93,18 +94,18 @@ var ErrResourceDoesNotExist = bittorrent.ClientError("resource does not exist") // DataStorage is the interface, used for implementing store for arbitrary data type DataStorage interface { // Put used to place arbitrary k-v data with specified context - // into storage. ctx parameter used to group data + // into storage. storeCtx parameter used to group data // (i.e. data only for specific middleware module: hash key, table name etc...) - Put(ctx string, values ...Entry) error + Put(ctx context.Context, storeCtx string, values ...Entry) error // Contains checks if any data in specified context exist - Contains(ctx string, key string) (bool, error) + Contains(ctx context.Context, storeCtx string, key string) (bool, error) // Load used to get arbitrary data in specified context by its key - Load(ctx string, key string) ([]byte, error) + Load(ctx context.Context, storeCtx string, key string) ([]byte, error) // Delete used to delete arbitrary data in specified context by its keys - Delete(ctx string, keys ...string) error + Delete(ctx context.Context, storeCtx string, keys ...string) error // Preservable indicates, that this storage can store data permanently, // in other words, is NOT in-memory storage, which data will be lost after restart @@ -130,33 +131,33 @@ type PeerStorage interface { DataStorage // PutSeeder adds a Seeder to the Swarm identified by the provided // InfoHash. - PutSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error + PutSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error // DeleteSeeder removes a Seeder from the Swarm identified by the // provided InfoHash. // // If the Swarm or Peer does not exist, this function returns // ErrResourceDoesNotExist. - DeleteSeeder(ih bittorrent.InfoHash, peer bittorrent.Peer) error + DeleteSeeder(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error // PutLeecher adds a Leecher to the Swarm identified by the provided // InfoHash. // If the Swarm does not exist already, it is created. - PutLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error + PutLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error // DeleteLeecher removes a Leecher from the Swarm identified by the // provided InfoHash. // // If the Swarm or Peer does not exist, this function returns // ErrResourceDoesNotExist. - DeleteLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error + DeleteLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error // GraduateLeecher promotes a Leecher to a Seeder in the Swarm // identified by the provided InfoHash. // // If the given Peer is not present as a Leecher or the swarm does not exist // already, the Peer is added as a Seeder and no error is returned. - GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) error + GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, peer bittorrent.Peer) error // AnnouncePeers is a best effort attempt to return Peers from the Swarm // identified by the provided InfoHash. @@ -173,7 +174,7 @@ type PeerStorage interface { // leechers // // Returns ErrResourceDoesNotExist if the provided InfoHash is not tracked. - AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) + AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) (peers []bittorrent.Peer, err error) // ScrapeSwarm returns information required to answer a Scrape request // about a Swarm identified by the given InfoHash. @@ -183,11 +184,11 @@ type PeerStorage interface { // filling the Snatches field is optional. // // If the Swarm does not exist, an empty Scrape and no error is returned. - ScrapeSwarm(ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) + ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (leechers uint32, seeders uint32, snatched uint32) // Ping used for checks if storage is alive // (connection could be established, enough space etc.) - Ping() error + Ping(ctx context.Context) error // GCAware marks that this storage supports periodic // peers collection diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index d822cc7..a8023d0 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -125,7 +125,7 @@ func (bh *benchHolder) Nop(b *testing.B) { // Put can run in parallel. func (bh *benchHolder) Put(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(bd.infoHashes[0], bd.peers[0]) + return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) }) } @@ -135,7 +135,7 @@ func (bh *benchHolder) Put(b *testing.B) { // Put1k can run in parallel. func (bh *benchHolder) Put1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -145,7 +145,7 @@ func (bh *benchHolder) Put1k(b *testing.B) { // Put1kInfoHash can run in parallel. func (bh *benchHolder) Put1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -155,7 +155,7 @@ func (bh *benchHolder) Put1kInfoHash(b *testing.B) { // Put1kInfoHash1k can run in parallel. func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -166,11 +166,11 @@ func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { // PutDelete can not run in parallel. func (bh *benchHolder) PutDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(bd.infoHashes[0], bd.peers[0]) + err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) }) } @@ -180,11 +180,11 @@ func (bh *benchHolder) PutDelete(b *testing.B) { // PutDelete1k can not run in parallel. func (bh *benchHolder) PutDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -194,11 +194,11 @@ func (bh *benchHolder) PutDelete1k(b *testing.B) { // PutDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -208,11 +208,11 @@ func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { // PutDelete1kInfoHash1k can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -223,7 +223,7 @@ func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { // DeleteNonexist can run in parallel. func (bh *benchHolder) DeleteNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0]) + _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -234,7 +234,7 @@ func (bh *benchHolder) DeleteNonexist(b *testing.B) { // DeleteNonexist can run in parallel. func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -245,7 +245,7 @@ func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { // DeleteNonexist1kInfoHash can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -256,7 +256,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { // DeleteNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -267,7 +267,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { // GradNonexist can run in parallel. func (bh *benchHolder) GradNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[0]) + _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -278,7 +278,7 @@ func (bh *benchHolder) GradNonexist(b *testing.B) { // GradNonexist1k can run in parallel. func (bh *benchHolder) GradNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -289,7 +289,7 @@ func (bh *benchHolder) GradNonexist1k(b *testing.B) { // GradNonexist1kInfoHash can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -301,7 +301,7 @@ func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { // GradNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -313,15 +313,15 @@ func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { // PutGradDelete can not run in parallel. func (bh *benchHolder) PutGradDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(bd.infoHashes[0], bd.peers[0]) + err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[0]) + err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) }) } @@ -330,15 +330,15 @@ func (bh *benchHolder) PutGradDelete(b *testing.B) { // PutGradDelete1k can not run in parallel. func (bh *benchHolder) PutGradDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(bd.infoHashes[0], bd.peers[i%peersCount]) + err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -348,15 +348,15 @@ func (bh *benchHolder) PutGradDelete1k(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[0]) + err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -366,15 +366,15 @@ func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutGradDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -385,9 +385,9 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { for i, peer := range bd.peers { var err error if i < l/2 { - err = ps.PutLeecher(ih, peer) + err = ps.PutLeecher(nil, ih, peer) } else { - err = ps.PutSeeder(ih, peer) + err = ps.PutSeeder(nil, ih, peer) } if err != nil { return err @@ -404,7 +404,7 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { // AnnounceLeecher can run in parallel. func (bh *benchHolder) AnnounceLeecher(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -415,7 +415,7 @@ func (bh *benchHolder) AnnounceLeecher(b *testing.B) { // AnnounceLeecher1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -426,7 +426,7 @@ func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { // AnnounceSeeder can run in parallel. func (bh *benchHolder) AnnounceSeeder(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -437,7 +437,7 @@ func (bh *benchHolder) AnnounceSeeder(b *testing.B) { // AnnounceSeeder1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -448,7 +448,7 @@ func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { // ScrapeSwarm can run in parallel. func (bh *benchHolder) ScrapeSwarm(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(bd.infoHashes[0]) + ps.ScrapeSwarm(nil, bd.infoHashes[0]) return nil }) } @@ -458,7 +458,7 @@ func (bh *benchHolder) ScrapeSwarm(b *testing.B) { // ScrapeSwarm1kInfoHash can run in parallel. func (bh *benchHolder) ScrapeSwarm1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(bd.infoHashes[i%ihCount]) + ps.ScrapeSwarm(nil, bd.infoHashes[i%ihCount]) return nil }) } diff --git a/storage/test/storage_test_base.go b/storage/test/storage_test_base.go index 3e746c0..a14b3f4 100644 --- a/storage/test/storage_test_base.go +++ b/storage/test/storage_test_base.go @@ -40,7 +40,7 @@ type hashPeer struct { func (th *testHolder) DeleteSeeder(t *testing.T) { for _, c := range testData { - err := th.st.DeleteSeeder(c.ih, c.peer) + err := th.st.DeleteSeeder(nil, c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -54,14 +54,14 @@ func (th *testHolder) PutLeecher(t *testing.T) { if c.peer.Addr().Is6() { peer = v6Peer } - err := th.st.PutLeecher(c.ih, peer) + err := th.st.PutLeecher(nil, c.ih, peer) require.Nil(t, err) } } func (th *testHolder) DeleteLeecher(t *testing.T) { for _, c := range testData { - err := th.st.DeleteLeecher(c.ih, c.peer) + err := th.st.DeleteLeecher(nil, c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -71,7 +71,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) { func (th *testHolder) AnnouncePeers(t *testing.T) { for _, c := range testData { - _, err := th.st.AnnouncePeers(c.ih, false, 50, c.peer.Addr().Is6()) + _, err := th.st.AnnouncePeers(nil, c.ih, false, 50, c.peer.Addr().Is6()) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -81,7 +81,7 @@ func (th *testHolder) AnnouncePeers(t *testing.T) { func (th *testHolder) ScrapeSwarm(t *testing.T) { for _, c := range testData { - l, s, n := th.st.ScrapeSwarm(c.ih) + l, s, n := th.st.ScrapeSwarm(nil, c.ih) require.Equal(t, uint32(0), s) require.Equal(t, uint32(0), l) require.Equal(t, uint32(0), n) @@ -91,26 +91,26 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) { func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutLeecher(c.ih, c.peer) + err := th.st.PutLeecher(nil, c.ih, c.peer) require.Nil(t, err) - peers, err := th.st.AnnouncePeers(c.ih, true, 50, isV6) + peers, err := th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // non-seeder announce should still return the leecher - peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(c.ih) + l, s, _ := th.st.ScrapeSwarm(nil, c.ih) require.Equal(t, uint32(2), l) require.Equal(t, uint32(0), s) - err = th.st.DeleteLeecher(c.ih, c.peer) + err = th.st.DeleteLeecher(nil, c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(c.ih, true, 50, isV6) + peers, err = th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -122,22 +122,22 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutSeeder(c.ih, c.peer) + err := th.st.PutSeeder(nil, c.ih, c.peer) require.Nil(t, err) // Should be leecher to see the seeder - peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(c.ih) + l, s, _ := th.st.ScrapeSwarm(nil, c.ih) require.Equal(t, uint32(1), l) require.Equal(t, uint32(1), s) - err = th.st.DeleteSeeder(c.ih, c.peer) + err = th.st.DeleteSeeder(nil, c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -153,44 +153,44 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { if isV6 { peer = v6Peer } - err := th.st.PutLeecher(c.ih, c.peer) + err := th.st.PutLeecher(nil, c.ih, c.peer) require.Nil(t, err) - err = th.st.GraduateLeecher(c.ih, c.peer) + err = th.st.GraduateLeecher(nil, c.ih, c.peer) require.Nil(t, err) // Has to be leecher to see the graduated seeder - peers, err := th.st.AnnouncePeers(c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Deleting the Peer as a Leecher should have no effect - err = th.st.DeleteLeecher(c.ih, c.peer) + err = th.st.DeleteLeecher(nil, c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) // Verify it's still there - peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Clean up - err = th.st.DeleteLeecher(c.ih, peer) + err = th.st.DeleteLeecher(nil, c.ih, peer) require.Nil(t, err) // Test ErrDNE for missing leecher - err = th.st.DeleteLeecher(c.ih, peer) + err = th.st.DeleteLeecher(nil, c.ih, peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) - err = th.st.DeleteSeeder(c.ih, c.peer) + err = th.st.DeleteSeeder(nil, c.ih, c.peer) require.Nil(t, err) - err = th.st.DeleteSeeder(c.ih, c.peer) + err = th.st.DeleteSeeder(nil, c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -200,35 +200,35 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { for _, c := range testData { - err := th.st.Put(kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) + err := th.st.Put(nil, kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) require.Nil(t, err) // check if exist in ctx we put - contains, err := th.st.Contains(kvStoreCtx, c.peer.String()) + contains, err := th.st.Contains(nil, kvStoreCtx, c.peer.String()) require.Nil(t, err) require.True(t, contains) // check if not exist in another ctx - contains, err = th.st.Contains("", c.peer.String()) + contains, err = th.st.Contains(nil, "", c.peer.String()) require.Nil(t, err) require.False(t, contains) // check value and type in ctx we put - out, err := th.st.Load(kvStoreCtx, c.peer.String()) + out, err := th.st.Load(nil, kvStoreCtx, c.peer.String()) require.Nil(t, err) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, c.ih, ih) // check value is nil in another ctx - dummy, err := th.st.Load("", c.peer.String()) + dummy, err := th.st.Load(nil, "", c.peer.String()) require.Nil(t, err) require.Nil(t, dummy) - err = th.st.Delete(kvStoreCtx, c.peer.String()) + err = th.st.Delete(nil, kvStoreCtx, c.peer.String()) require.Nil(t, err) - contains, err = th.st.Contains("", c.peer.String()) + contains, err = th.st.Contains(nil, "", c.peer.String()) require.Nil(t, err) require.False(t, contains) } @@ -245,29 +245,29 @@ func (th *testHolder) CustomBulkPutContainsLoadDelete(t *testing.T) { Value: []byte(c.ih.RawString()), }) } - err := th.st.Put(kvStoreCtx, pairs...) + err := th.st.Put(nil, kvStoreCtx, pairs...) require.Nil(t, err) // check if exist in ctx we put for _, k := range keys { - contains, err := th.st.Contains(kvStoreCtx, k) + contains, err := th.st.Contains(nil, kvStoreCtx, k) require.Nil(t, err) require.True(t, contains) } // check value and type in ctx we put for _, p := range pairs { - out, _ := th.st.Load(kvStoreCtx, p.Key) + out, _ := th.st.Load(nil, kvStoreCtx, p.Key) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, p.Value, []byte(ih.RawString())) } - err = th.st.Delete(kvStoreCtx, keys...) + err = th.st.Delete(nil, kvStoreCtx, keys...) require.Nil(t, err) for _, k := range keys { - contains, err := th.st.Contains(kvStoreCtx, k) + contains, err := th.st.Contains(nil, kvStoreCtx, k) require.Nil(t, err) require.False(t, contains) } From d8d6f2cf4ba769efe47ba7d2112afd4f1fc66feb Mon Sep 17 00:00:00 2001 From: "Lawrence, Rendall" Date: Tue, 1 Nov 2022 17:02:24 +0300 Subject: [PATCH 2/3] refactor code for context fallthrough --- bittorrent/params.go | 22 +++ cmd/mochi/main.go | 2 +- cmd/mochi/server.go | 85 +++++++++--- dist/example_config.yaml | 12 +- frontend/frontend.go | 4 +- frontend/http/frontend.go | 56 +++----- frontend/udp/frontend.go | 47 +++---- frontend/udp/frontend_test.go | 5 +- frontend/udp/writer.go | 20 ++- middleware/jwt/jwt.go | 8 +- middleware/logic.go | 23 ---- .../container/directory/directory.go | 23 ++-- middleware/torrentapproval/torrentapproval.go | 10 +- pkg/metrics/server.go | 12 +- pkg/stop/stop.go | 125 ------------------ storage/keydb/storage.go | 28 +--- storage/memory/storage.go | 33 ++--- storage/pg/storage.go | 51 ++++--- storage/redis/storage.go | 43 ++---- storage/storage.go | 31 ++--- storage/test/storage_bench.go | 86 ++++++------ storage/test/storage_test_base.go | 79 +++++------ 22 files changed, 312 insertions(+), 493 deletions(-) delete mode 100644 pkg/stop/stop.go diff --git a/bittorrent/params.go b/bittorrent/params.go index 3a08b64..819bfe5 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -1,6 +1,7 @@ package bittorrent import ( + "context" "errors" "net/url" "strconv" @@ -84,6 +85,27 @@ func (rp RouteParams) ByName(name string) string { return "" } +// InjectRouteParamsToContext returns new context with specified RouteParams placed in +// RouteParamsKey key +func InjectRouteParamsToContext(ctx context.Context, rp RouteParams) context.Context { + if rp == nil { + rp = RouteParams{} + } + return context.WithValue(ctx, RouteParamsKey, rp) +} + +// RemapRouteParamsToBgContext returns new context with context.Background parent +// and copied RouteParams from inCtx +func RemapRouteParamsToBgContext(inCtx context.Context) context.Context { + rp, isOk := inCtx.Value(RouteParamsKey).(RouteParams) + if !isOk { + rp = RouteParams{} + } else { + logger.Warn().Msg("unable to fetch route parameters, probably jammed context") + } + return context.WithValue(context.Background(), RouteParamsKey, rp) +} + // ParseURLData parses a request URL or UDP URLData as defined in BEP41. // It expects a concatenated string of the request's path and query parts as // defined in RFC 3986. As both the udp: and http: scheme used by BitTorrent diff --git a/cmd/mochi/main.go b/cmd/mochi/main.go index 6b26686..f8564e6 100644 --- a/cmd/mochi/main.go +++ b/cmd/mochi/main.go @@ -37,7 +37,7 @@ func main() { if err := s.Run(*configPath); err != nil { log.Fatal("unable to start server: ", err) } - defer s.Dispose() + defer s.Shutdown() ch := make(chan os.Signal, 2) signal.Notify(ch, os.Interrupt, syscall.SIGTERM) <-ch diff --git a/cmd/mochi/server.go b/cmd/mochi/server.go index d3dc243..94a618c 100644 --- a/cmd/mochi/server.go +++ b/cmd/mochi/server.go @@ -3,20 +3,23 @@ package main import ( "errors" "fmt" + "io" + "sync" + + "github.com/rs/zerolog" "github.com/sot-tech/mochi/frontend" "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) // Server represents the state of a running instance. type Server struct { - storage storage.PeerStorage - logic *middleware.Logic - sg *stop.Group + frontends []io.Closer + hooks []io.Closer + storage storage.PeerStorage } // Run begins an instance of Conf. @@ -28,11 +31,9 @@ func (r *Server) Run(configFilePath string) error { return fmt.Errorf("failed to read config: %w", err) } - r.sg = stop.NewGroup() - if len(cfg.MetricsAddr) > 0 { log.Info().Str("address", cfg.MetricsAddr).Msg("starting metrics server") - r.sg.Add(metrics.NewServer(cfg.MetricsAddr)) + r.frontends = append(r.frontends, metrics.NewServer(cfg.MetricsAddr)) } else { log.Info().Msg("metrics disabled because of empty address") } @@ -46,17 +47,30 @@ func (r *Server) Run(configFilePath string) error { if err != nil { return fmt.Errorf("failed to configure pre-hooks: %w", err) } + + for _, h := range preHooks { + if c, isOk := h.(io.Closer); isOk { + r.hooks = append(r.hooks, c) + } + } + postHooks, err := middleware.NewHooks(cfg.PostHooks, r.storage) if err != nil { return fmt.Errorf("failed to configure post-hooks: %w", err) } + for _, h := range postHooks { + if c, isOk := h.(io.Closer); isOk { + r.hooks = append(r.hooks, c) + } + } + if len(cfg.Frontends) > 0 { var fs []frontend.Frontend - r.logic = middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks) - if fs, err = frontend.NewFrontends(cfg.Frontends, r.logic); err == nil { + logic := middleware.NewLogic(cfg.AnnounceInterval, cfg.MinAnnounceInterval, r.storage, preHooks, postHooks) + if fs, err = frontend.NewFrontends(cfg.Frontends, logic); err == nil { for _, f := range fs { - r.sg.Add(f) + r.frontends = append(r.frontends, f) } } else { err = fmt.Errorf("failed to configure frontends: %w", err) @@ -68,21 +82,50 @@ func (r *Server) Run(configFilePath string) error { return err } -// Dispose shuts down an instance of Server. -func (r *Server) Dispose() { +// Shutdown shuts down an instance of Server. +func (r *Server) Shutdown() { log.Debug().Msg("stopping frontends and metrics server") - if errs := r.sg.Stop().Wait(); len(errs) > 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down frontends") - } + closeGroup(r.frontends).Msg("frontends stopped") - log.Debug().Msg("stopping logic") - if errs := r.logic.Stop().Wait(); len(errs) > 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down middlewares") - } + log.Debug().Msg("stopping middleware") + closeGroup(r.hooks).Msg("hooks stopped") log.Debug().Msg("stopping peer store") - if errs := r.storage.Stop().Wait(); len(errs) != 0 { - log.Error().Errs("errors", errs).Msg("error occurred while shutting down peer store") + var err error + if r.storage != nil { + err = r.storage.Close() + } else { + err = errors.New("peer store not configured") } + log.Err(err).Msg("peer store stopped") log.Close() } + +func closeGroup(cls []io.Closer) (e *zerolog.Event) { + l := len(cls) + errs := make([]error, l) + wg := sync.WaitGroup{} + wg.Add(l) + for i, cl := range cls { + go func(i int, cl io.Closer) { + defer wg.Done() + if e := cl.Close(); e != nil { + errs[i] = e + } + }(i, cl) + } + wg.Wait() + nnErrs := make([]error, 0, l) + for _, e := range errs { + if e != nil { + nnErrs = append(nnErrs, e) + } + } + var evt *zerolog.Event + if len(nnErrs) > 0 { + evt = log.Error().Errs("errors", nnErrs) + } else { + evt = log.Info() + } + return evt +} diff --git a/dist/example_config.yaml b/dist/example_config.yaml index 3713af9..3e92b1c 100644 --- a/dist/example_config.yaml +++ b/dist/example_config.yaml @@ -14,13 +14,15 @@ min_announce_interval: 15m # /debug/pprof/{cmdline,profile,symbol,trace} serves profiles in the pprof format metrics_addr: "0.0.0.0:6880" +# This block defines named configurations of network listeners (frontends). +# At least one listener should be provided. frontends: # This block defines configuration for the tracker's HTTP interface. # If you do not wish to run this, delete this section. - name: http config: # The network interface that will bind to an HTTP server for serving - # BitTorrent traffic. Remove this to disable the non-TLS listener. + # BitTorrent traffic. addr: "0.0.0.0:6969" # Mark this frontend as HTTPS server for serving @@ -32,6 +34,8 @@ frontends: tls_key_path: "" # Enable SO_REUSEPORT to allow starting multiple mochi instances with the same HTTP(S) port. + # You can also use this parameter to define two or mote listeners for the same address and port, + # and (possibly) increase throughput. reuse_port: true # The timeout durations for HTTP requests. @@ -97,8 +101,8 @@ frontends: # The maximum number of infohashes that can be scraped in one request. max_scrape_infohashes: 50 - # This block defines configuration for the tracker's UDP interface. - # If you do not wish to run this, delete this section. + # This block defines configuration for the tracker's UDP interface. + # If you do not wish to run this, delete this section. - name: udp config: # The network interface that will bind to a UDP server for serving @@ -106,6 +110,8 @@ frontends: addr: "0.0.0.0:6969" # Enable SO_REUSEPORT to allow starting multiple mochi instances with the same UDP port. + # You can also use this parameter to define two or mote listeners for the same address and port, + # and (a little) increase throughput (faster queue processing because of multiple 'workers'). reuse_port: true # The leeway for a timestamp on a connection ID. diff --git a/frontend/frontend.go b/frontend/frontend.go index e871d96..b00c1e1 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -4,12 +4,12 @@ package frontend import ( "fmt" + "io" "sync" "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) var ( @@ -46,7 +46,7 @@ func RegisterBuilder(name string, b Builder) { // Frontend dummy interface for bittorrent frontends type Frontend interface { - stop.Stopper + io.Closer } // NewFrontends is a utility function for initializing Frontend-s in bulk. diff --git a/frontend/http/frontend.go b/frontend/http/frontend.go index fe93992..4e39661 100644 --- a/frontend/http/frontend.go +++ b/frontend/http/frontend.go @@ -18,7 +18,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" ) var ( @@ -81,17 +80,17 @@ type httpFE struct { } // NewFrontend builds and starts http bittorrent frontend from provided configuration -func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend, err error) { +func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, error) { var cfg Config + var err error if err = c.Unmarshal(&cfg); err != nil { - return + return nil, err } if cfg, err = cfg.Validate(); err != nil { - return + return nil, err } if len(cfg.AnnounceRoutes) < 1 || len(cfg.ScrapeRoutes) < 1 { - err = errRoutesNotProvided - return + return nil, errRoutesNotProvided } f := &httpFE{ @@ -110,7 +109,7 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend if cfg.UseTLS { var cert tls.Certificate if cert, err = tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath); err != nil { - return + return nil, err } f.srv.TLSConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -150,34 +149,17 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (_ frontend.Frontend return f, nil } -// Stop provides a thread-safe way to shut down a currently running Frontend. -func (f *httpFE) Stop() stop.Result { - c := make(stop.Channel) - if f.srv != nil { - go func() { - c.Done(f.srv.Shutdown(context.Background())) - }() - } - return c.Result() +// Close provides a thread-safe way to shut down a currently running Frontend. +func (f *httpFE) Close() error { + return f.srv.Shutdown(context.Background()) } -func injectRouteParamsToContext(ctx context.Context, ps httprouter.Params) context.Context { - rp := bittorrent.RouteParams{} - for _, p := range ps { - rp = append(rp, bittorrent.RouteParam{Key: p.Key, Value: p.Value}) +func httpParamsToRouteParams(in httprouter.Params) (out bittorrent.RouteParams) { + out = make([]bittorrent.RouteParam, 0, len(in)) + for _, p := range in { + out = append(out, bittorrent.RouteParam{Key: p.Key, Value: p.Value}) } - return context.WithValue(ctx, bittorrent.RouteParamsKey, rp) -} - -func remapRouteParamsToBgContext(inCtx context.Context) context.Context { - rp, isOk := inCtx.Value(bittorrent.RouteParamsKey).(bittorrent.RouteParams) - if !isOk { - rp = bittorrent.RouteParams{} - } else { - logger.Warn().Msg("unable to fetch route parameters, probably jammed context") - } - // FIXME: cancelable context - return context.WithValue(context.TODO(), bittorrent.RouteParamsKey, rp) + return } // announceRoute parses and responds to an Announce. @@ -200,7 +182,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro } addr = req.GetFirst() - ctx := injectRouteParamsToContext(r.Context(), ps) + ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps)) ctx, resp, err := f.logic.HandleAnnounce(ctx, req) if err != nil { WriteError(w, err) @@ -215,8 +197,7 @@ func (f *httpFE) announceRoute(w http.ResponseWriter, r *http.Request, ps httpro } // next actions are background and should not be canceled after http writer closed - ctx = remapRouteParamsToBgContext(ctx) - + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go f.logic.AfterAnnounce(ctx, req, resp) } @@ -239,7 +220,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout } addr = req.GetFirst() - ctx := injectRouteParamsToContext(r.Context(), ps) + ctx := bittorrent.InjectRouteParamsToContext(r.Context(), httpParamsToRouteParams(ps)) ctx, resp, err := f.logic.HandleScrape(ctx, req) if err != nil { WriteError(w, err) @@ -254,8 +235,7 @@ func (f *httpFE) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httprout } // next actions are background and should not be canceled after http writer closed - ctx = remapRouteParamsToBgContext(ctx) - + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go f.logic.AfterScrape(ctx, req, resp) } diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index 6dadd61..97c960e 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -20,7 +20,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" ) @@ -81,6 +80,8 @@ type udpFE struct { logic *middleware.Logic maxClockSkew time.Duration collectTimings bool + ctxCancel context.CancelFunc + onceCloser sync.Once frontend.ParseOptions } @@ -109,43 +110,37 @@ func NewFrontend(c conf.MapConfig, logic *middleware.Logic) (frontend.Frontend, } if f.socket, err = cfg.ListenUDP(); err == nil { + var ctx context.Context + ctx, f.ctxCancel = context.WithCancel(context.Background()) f.wg.Add(1) - go func() { - if err := f.serve(); err != nil { + go func(ctx context.Context) { + if err := f.serve(ctx); err != nil { logger.Fatal().Err(err).Msg("server failed") } - }() + }(ctx) } return f, err } -// Stop provides a thread-safe way to shut down a currently running Frontend. -func (t *udpFE) Stop() stop.Result { - select { - case <-t.closing: - return stop.AlreadyStopped - default: - } - - c := make(stop.Channel) - go func() { +// Close provides a thread-safe way to shut down a currently running Frontend. +func (t *udpFE) Close() (err error) { + t.onceCloser.Do(func() { close(t.closing) - var err error if t.socket != nil { + t.ctxCancel() _ = t.socket.SetReadDeadline(time.Now()) t.wg.Wait() err = t.socket.Close() } - c.Done(err) - }() + }) - return c.Result() + return } // serve blocks while listening and serving UDP BitTorrent requests // until Stop() is called or an error is returned. -func (t *udpFE) serve() error { +func (t *udpFE) serve(ctx context.Context) error { pool := bytepool.NewBytePool(2048) defer t.wg.Done() @@ -188,7 +183,7 @@ func (t *udpFE) serve() error { if t.collectTimings && metrics.Enabled() { start = time.Now() } - action, err := t.handleRequest( + action, err := t.handleRequest(ctx, Request{(*buffer)[:n], addr}, ResponseWriter{t.socket, addrPort}, ) @@ -218,7 +213,7 @@ func (w ResponseWriter) Write(b []byte) (int, error) { } // handleRequest parses and responds to a UDP Request. -func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, err error) { +func (t *udpFE) handleRequest(ctx context.Context, r Request, w ResponseWriter) (actionName string, err error) { if len(r.Packet) < 16 { // Malformed, no client packets are less than 16 bytes. // We explicitly return nothing in case this is a DoS attempt. @@ -265,9 +260,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e return } - var ctx context.Context var resp *bittorrent.AnnounceResponse - ctx, resp, err = t.logic.HandleAnnounce(context.Background(), req) + ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{}) + ctx, resp, err = t.logic.HandleAnnounce(ctx, req) if err != nil { WriteError(w, txID, err) return @@ -275,6 +270,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e WriteAnnounce(w, txID, resp, actionID == announceV6ActionID, r.IP.Is6()) + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go t.logic.AfterAnnounce(ctx, req, resp) case scrapeActionID: @@ -287,9 +283,9 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e return } - var ctx context.Context var resp *bittorrent.ScrapeResponse - ctx, resp, err = t.logic.HandleScrape(context.Background(), req) + ctx := bittorrent.InjectRouteParamsToContext(ctx, bittorrent.RouteParams{}) + ctx, resp, err = t.logic.HandleScrape(ctx, req) if err != nil { WriteError(w, txID, err) return @@ -297,6 +293,7 @@ func (t *udpFE) handleRequest(r Request, w ResponseWriter) (actionName string, e WriteScrape(w, txID, resp) + ctx = bittorrent.RemapRouteParamsToBgContext(ctx) go t.logic.AfterScrape(ctx, req, resp) default: diff --git a/frontend/udp/frontend_test.go b/frontend/udp/frontend_test.go index 8ae8b8c..0313e9b 100644 --- a/frontend/udp/frontend_test.go +++ b/frontend/udp/frontend_test.go @@ -29,8 +29,7 @@ func TestStartStopRaceIssue437(t *testing.T) { if err != nil { t.Fatal(err) } - errC := fe.Stop() - if errs := <-errC; len(errs) != 0 { - t.Fatal(errs) + if err = fe.Close(); err != nil { + t.Fatal(err) } } diff --git a/frontend/udp/writer.go b/frontend/udp/writer.go index c6d8e34..5ed0efe 100644 --- a/frontend/udp/writer.go +++ b/frontend/udp/writer.go @@ -3,7 +3,6 @@ package udp import ( "encoding/binary" "errors" - "fmt" "io" "time" @@ -12,17 +11,15 @@ import ( // WriteError writes the failure reason as a null-terminated string. func WriteError(w io.Writer, txID []byte, err error) { - // If the client wasn't at fault, acknowledge it. - var clientErr bittorrent.ClientError - if !errors.Is(err, &clientErr) { - err = fmt.Errorf("internal error occurred: %w", err) - } - buf := reqRespBufferPool.Get() defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, errorActionID) - _, _ = buf.WriteString(err.Error()) - _, _ = buf.WriteRune('\000') + // If the client wasn't at fault, acknowledge it. + if !errors.As(err, new(bittorrent.ClientError)) { + buf.WriteString("internal error occurred: ") + } + buf.WriteString(err.Error()) + buf.WriteByte('\000') _, _ = w.Write(buf.Bytes()) } @@ -50,7 +47,7 @@ func WriteAnnounce(w io.Writer, txID []byte, resp *bittorrent.AnnounceResponse, } for _, peer := range peers { - _, _ = buf.Write(peer.Addr().AsSlice()) + buf.Write(peer.Addr().AsSlice()) _ = binary.Write(buf, binary.BigEndian, peer.Port()) } @@ -79,8 +76,7 @@ func WriteConnectionID(w io.Writer, txID, connID []byte) { defer reqRespBufferPool.Put(buf) writeHeader(buf, txID, connectActionID) - _, _ = buf.Write(connID) - + buf.Write(connID) _, _ = w.Write(buf.Bytes()) } diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 34556ae..65c6eaf 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -19,7 +19,6 @@ import ( "github.com/sot-tech/mochi/middleware" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -110,13 +109,12 @@ func build(config conf.MapConfig, _ storage.PeerStorage) (h middleware.Hook, err return } -func (h *hook) Stop() stop.Result { +func (h *hook) Close() error { logger.Debug().Msg("attempting to shutdown JWT middleware") - c := make(stop.Channel) if h.jwks != nil { - go h.jwks.EndBackground() + h.jwks.EndBackground() } - return c.Result() + return nil } type verifiableClaims interface { diff --git a/middleware/logic.go b/middleware/logic.go index 458c156..26495e1 100644 --- a/middleware/logic.go +++ b/middleware/logic.go @@ -6,7 +6,6 @@ import ( "time" "github.com/sot-tech/mochi/bittorrent" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -119,25 +118,3 @@ func (l *Logic) Ping(ctx context.Context) (err error) { } return } - -// Stop stops the Logic. -// -// This stops any hooks that implement stop.Stopper. -func (l *Logic) Stop() stop.Result { - stopGroup := stop.NewGroup() - for _, hook := range l.preHooks { - stoppable, ok := hook.(stop.Stopper) - if ok { - stopGroup.Add(stoppable) - } - } - - for _, hook := range l.postHooks { - stoppable, ok := hook.(stop.Stopper) - if ok { - stopGroup.Add(stoppable) - } - } - - return stopGroup.Stop() -} diff --git a/middleware/torrentapproval/container/directory/directory.go b/middleware/torrentapproval/container/directory/directory.go index 9358cbc..16714c8 100644 --- a/middleware/torrentapproval/container/directory/directory.go +++ b/middleware/torrentapproval/container/directory/directory.go @@ -17,7 +17,6 @@ import ( "github.com/sot-tech/mochi/middleware/torrentapproval/container/list" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -49,11 +48,6 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er }, watcher: nil, } - var w *dirwatch.Instance - if w, err = dirwatch.New(c.Path); err != nil { - return nil, fmt.Errorf("unable to initialize directory watch: %w", err) - } - d.watcher = w if len(d.StorageCtx) == 0 { logger.Warn(). Str("name", "StorageCtx"). @@ -62,6 +56,11 @@ func build(conf conf.MapConfig, st storage.DataStorage) (container.Container, er Msg("falling back to default configuration") d.StorageCtx = container.DefaultStorageCtxName } + var w *dirwatch.Instance + if w, err = dirwatch.New(c.Path); err != nil { + return nil, fmt.Errorf("unable to initialize directory watch: %w", err) + } + d.watcher = w go func() { for event := range d.watcher.Events { var mi *metainfo.MetaInfo @@ -124,10 +123,10 @@ type directory struct { watcher *dirwatch.Instance } -// Stop closes watching of torrent directory -func (d *directory) Stop() stop.Result { - st := make(stop.Channel) - d.watcher.Close() - st.Done() - return st.Result() +// Close closes watching of torrent directory +func (d *directory) Close() error { + if d.watcher != nil { + d.watcher.Close() + } + return nil } diff --git a/middleware/torrentapproval/torrentapproval.go b/middleware/torrentapproval/torrentapproval.go index 93ca348..977f8b1 100644 --- a/middleware/torrentapproval/torrentapproval.go +++ b/middleware/torrentapproval/torrentapproval.go @@ -5,6 +5,7 @@ package torrentapproval import ( "context" "fmt" + "io" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/middleware" @@ -17,7 +18,6 @@ import ( // import static list to enable appropriate support _ "github.com/sot-tech/mochi/middleware/torrentapproval/container/list" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" ) @@ -86,9 +86,9 @@ func (h *hook) HandleScrape(ctx context.Context, _ *bittorrent.ScrapeRequest, _ return ctx, nil } -func (h *hook) Stop() stop.Result { - if st, isOk := h.hashContainer.(stop.Stopper); isOk { - return st.Stop() +func (h *hook) Close() (err error) { + if cl, isOk := h.hashContainer.(io.Closer); isOk { + err = cl.Close() } - return stop.AlreadyStopped + return err } diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index 03a40bc..7801532 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -14,7 +14,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) const ( @@ -50,14 +49,9 @@ func AddressFamily(ip netip.Addr) string { } } -// Stop shuts down the server. -func (s *Server) Stop() stop.Result { - c := make(stop.Channel) - go func() { - c.Done(s.srv.Shutdown(context.Background())) - }() - - return c.Result() +// Close shuts down the server. +func (s *Server) Close() error { + return s.srv.Shutdown(context.Background()) } // NewServer creates a new instance of a Prometheus server that asynchronously diff --git a/pkg/stop/stop.go b/pkg/stop/stop.go deleted file mode 100644 index d3b29dd..0000000 --- a/pkg/stop/stop.go +++ /dev/null @@ -1,125 +0,0 @@ -// Package stop implements a pattern for shutting down a group of processes. -package stop - -import ( - "sync" -) - -// Channel is used to return zero or more errors asynchronously. Call Done() -// once to pass errors to the Channel. -type Channel chan []error - -// Result is a receive-only version of Channel. Call Wait() once to receive any -// returned errors. -type Result <-chan []error - -// Done adds zero or more errors to the Channel and closes it, indicating the -// caller has finished stopping. It should be called exactly once. -func (ch Channel) Done(errs ...error) { - if len(errs) > 0 && errs[0] != nil { - ch <- errs - } - close(ch) -} - -// Result converts a Channel to a Result. -func (ch Channel) Result() <-chan []error { - return ch -} - -// Wait blocks until Done() is called on the underlying Channel and returns any -// errors. It should be called exactly once. -func (r Result) Wait() []error { - return <-r -} - -// AlreadyStopped is a closed error channel to be used by Funcs when -// an element was already stopped. -var AlreadyStopped Result - -// AlreadyStoppedFunc is a Func that returns AlreadyStopped. -var AlreadyStoppedFunc = func() Result { return AlreadyStopped } - -func init() { - closeMe := make(Channel) - close(closeMe) - AlreadyStopped = closeMe.Result() -} - -// Stopper is an interface that allows a clean shutdown. -type Stopper interface { - // Stop returns a channel that indicates whether the stop was - // successful. - // - // The channel can either return one error or be closed. - // Closing the channel signals a clean shutdown. - // Stop() should return immediately and perform the actual shutdown in a - // separate goroutine. - Stop() Result -} - -// Func is a function that can be used to provide a clean shutdown. -type Func func() Result - -// Group is a collection of Stoppers that can be stopped all at once. -type Group struct { - stoppables []Func - sync.Mutex -} - -// NewGroup allocates a new Group. -func NewGroup() *Group { - return &Group{ - stoppables: make([]Func, 0), - } -} - -// Add appends a Stopper to the Group. -func (cg *Group) Add(toAdd Stopper) { - cg.Lock() - defer cg.Unlock() - - cg.stoppables = append(cg.stoppables, toAdd.Stop) -} - -// AddFunc appends a Func to the Group. -func (cg *Group) AddFunc(toAddFunc Func) { - cg.Lock() - defer cg.Unlock() - - cg.stoppables = append(cg.stoppables, toAddFunc) -} - -// Stop stops all members of the Group. -// -// Stopping will be done in a concurrent fashion. -// The slice of errors returned contains all errors returned by stopping the -// members. -func (cg *Group) Stop() Result { - cg.Lock() - defer cg.Unlock() - - whenDone := make(Channel) - - waitChannels := make([]Result, 0, len(cg.stoppables)) - for _, toStop := range cg.stoppables { - waitFor := toStop() - if waitFor == nil { - panic("received a nil chan from Stop") - } - waitChannels = append(waitChannels, waitFor) - } - - go func() { - var errors []error - for _, waitForMe := range waitChannels { - childErrors := waitForMe.Wait() - if len(childErrors) > 0 { - errors = append(errors, childErrors...) - } - } - whenDone.Done(errors...) - }() - - return whenDone.Result() -} diff --git a/storage/keydb/storage.go b/storage/keydb/storage.go index 7bdca86..59d65e5 100644 --- a/storage/keydb/storage.go +++ b/storage/keydb/storage.go @@ -12,14 +12,12 @@ package keydb import ( "context" "errors" - "time" "github.com/go-redis/redis/v8" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/storage" r "github.com/sot-tech/mochi/storage/redis" ) @@ -79,10 +77,7 @@ func newStore(cfg r.Config) (*store, error) { var st *store if err == nil { - st = &store{ - Connection: rs, - peerTTL: uint(cfg.PeerLifetime.Seconds()), - } + st = &store{Connection: rs, peerTTL: uint(cfg.PeerLifetime.Seconds())} } return st, err @@ -177,24 +172,3 @@ func (s *store) ScrapeSwarm(ctx context.Context, ih bittorrent.InfoHash) (uint32 Msg("scrape swarm") return s.ScrapeIH(ctx, ih, s.SCard) } - -func (*store) GCAware() bool { - return false -} - -func (*store) ScheduleGC(_, _ time.Duration) {} - -func (*store) StatisticsAware() bool { - return false -} - -func (*store) ScheduleStatisticsCollection(_ time.Duration) {} - -func (s *store) Stop() stop.Result { - c := make(stop.Channel) - if s.UniversalClient != nil { - c.Done(s.UniversalClient.Close()) - s.UniversalClient = nil - } - return c.Result() -} diff --git a/storage/memory/storage.go b/storage/memory/storage.go index a88c6ad..d62b37c 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -14,7 +14,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -96,8 +95,9 @@ type peerStore struct { cfg Config shards []*peerShard - closed chan struct{} - wg sync.WaitGroup + closed chan struct{} + wg sync.WaitGroup + onceCloser sync.Once } var _ storage.PeerStorage = &peerStore{} @@ -468,17 +468,9 @@ func (ds *dataStore) Delete(_ context.Context, ctx string, keys ...string) error return nil } -func (*dataStore) Preservable() bool { - return false -} +func (*dataStore) Preservable() bool { return false } -func (*peerStore) GCAware() bool { - return true -} - -func (*peerStore) StatisticsAware() bool { - return true -} +func (ds *dataStore) Close() error { return nil } // GC deletes all Peers from the PeerStorage which are older than the // cutoff time. @@ -542,12 +534,9 @@ func (*peerStore) Ping(context.Context) error { return nil } -func (ps *peerStore) Stop() stop.Result { - c := make(stop.Channel) - go func() { - if ps.closed != nil { - close(ps.closed) - } +func (ps *peerStore) Close() error { + ps.onceCloser.Do(func() { + close(ps.closed) ps.wg.Wait() // Explicitly deallocate our storage. @@ -556,9 +545,7 @@ func (ps *peerStore) Stop() stop.Result { shards[i] = &peerShard{swarms: make(map[bittorrent.InfoHash]swarm)} } ps.shards = shards + }) - c.Done() - }() - - return c.Result() + return nil } diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 0f238d0..32b69c6 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -20,7 +20,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -76,7 +75,13 @@ func newStore(cfg Config) (storage.PeerStorage, error) { return nil, err } - return &store{Config: cfg, Pool: con, wg: sync.WaitGroup{}, closed: make(chan any)}, nil + return &store{ + Config: cfg, + Pool: con, + wg: sync.WaitGroup{}, + closed: make(chan any), + onceCloser: sync.Once{}, + }, nil } type peerQueryConf struct { @@ -215,8 +220,9 @@ func (cfg Config) Validate() (Config, error) { type store struct { Config *pgxpool.Pool - wg sync.WaitGroup - closed chan any + wg sync.WaitGroup + closed chan any + onceCloser sync.Once } func (s *store) txBatch(ctx context.Context, batch *pgx.Batch) (err error) { @@ -242,7 +248,7 @@ func (s *store) Put(ctx context.Context, storeCtx string, values ...storage.Entr default: var batch pgx.Batch for _, v := range values { - batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: ctx, pKey: []byte(v.Key), pValue: v.Value}) + batch.Queue(s.Data.AddQuery, pgx.NamedArgs{pCtx: storeCtx, pKey: []byte(v.Key), pValue: v.Value}) } err = s.txBatch(ctx, &batch) } @@ -281,11 +287,10 @@ func (s *store) Preservable() bool { return true } -func (s *store) GCAware() bool { - return len(s.GCQuery) > 0 -} - func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) { + if len(s.GCQuery) == 0 { + return + } s.wg.Add(1) go func() { defer s.wg.Done() @@ -311,11 +316,10 @@ func (s *store) ScheduleGC(gcInterval, peerLifeTime time.Duration) { }() } -func (s *store) StatisticsAware() bool { - return len(s.InfoHashCountQuery) > 0 -} - func (s *store) ScheduleStatisticsCollection(reportInterval time.Duration) { + if len(s.InfoHashCountQuery) == 0 { + return + } s.wg.Add(1) go func() { defer s.wg.Done() @@ -357,9 +361,7 @@ func (s *store) putPeer(ctx context.Context, ih bittorrent.InfoHash, peer bittor pPort: peer.Port(), pSeeder: seeder, pV6: peer.Addr().Is6(), - } - if s.GCAware() { - args[pCreated] = timecache.Now() + pCreated: timecache.Now(), } _, err = s.Exec(ctx, s.Peer.AddQuery, args) return @@ -576,19 +578,12 @@ func (s *store) Ping(ctx context.Context) error { return err } -func (s *store) Stop() stop.Result { - c := make(stop.Channel) +func (s *store) Close() error { go func() { - if s.closed != nil { - close(s.closed) - } + close(s.closed) s.wg.Wait() - if s.Pool != nil { - logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.") - s.Close() - s.Pool = nil - } - c.Done() + logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.") + s.Pool.Close() }() - return c.Result() + return nil } diff --git a/storage/redis/storage.go b/storage/redis/storage.go index d50b4b9..daf19bc 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -36,7 +36,6 @@ import ( "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" "github.com/sot-tech/mochi/pkg/metrics" - "github.com/sot-tech/mochi/pkg/stop" "github.com/sot-tech/mochi/pkg/timecache" "github.com/sot-tech/mochi/storage" ) @@ -100,11 +99,7 @@ func newStore(cfg Config) (*store, error) { return nil, err } - return &store{ - Connection: rs, - closed: make(chan any), - wg: sync.WaitGroup{}, - }, nil + return &store{Connection: rs, closed: make(chan any)}, nil } // Config holds the configuration of a redis PeerStorage. @@ -287,8 +282,9 @@ type Connection struct { type store struct { Connection - closed chan any - wg sync.WaitGroup + closed chan any + wg sync.WaitGroup + onceCloser sync.Once } func (ps *store) count(key string, getLength bool) (n uint64) { @@ -611,14 +607,6 @@ func (*Connection) Preservable() bool { return true } -func (*store) GCAware() bool { - return true -} - -func (*store) StatisticsAware() bool { - return true -} - // Ping sends `PING` request to Redis server func (ps *Connection) Ping(ctx context.Context) error { return ps.UniversalClient.Ping(ctx).Err() @@ -769,21 +757,12 @@ func (ps *store) gc(cutoff time.Time) { } } -func (ps *store) Stop() stop.Result { - c := make(stop.Channel) - go func() { - if ps.closed != nil { - close(ps.closed) - } +func (ps *store) Close() (err error) { + ps.onceCloser.Do(func() { + close(ps.closed) ps.wg.Wait() - var err error - if ps.UniversalClient != nil { - logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey) - err = ps.UniversalClient.Close() - ps.UniversalClient = nil - } - c.Done(err) - }() - - return c.Result() + logger.Info().Msg("redis exiting. mochi does not clear data in redis when exiting. mochi keys have prefix " + PrefixKey) + err = ps.UniversalClient.Close() + }) + return } diff --git a/storage/storage.go b/storage/storage.go index bd61386..5b3aa73 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,13 +5,13 @@ package storage import ( "context" "fmt" + "io" "sync" "time" "github.com/sot-tech/mochi/bittorrent" "github.com/sot-tech/mochi/pkg/conf" "github.com/sot-tech/mochi/pkg/log" - "github.com/sot-tech/mochi/pkg/stop" ) const ( @@ -93,6 +93,7 @@ var ErrResourceDoesNotExist = bittorrent.ClientError("resource does not exist") // DataStorage is the interface, used for implementing store for arbitrary data type DataStorage interface { + io.Closer // Put used to place arbitrary k-v data with specified context // into storage. storeCtx parameter used to group data // (i.e. data only for specific middleware module: hash key, table name etc...) @@ -189,27 +190,23 @@ type PeerStorage interface { // Ping used for checks if storage is alive // (connection could be established, enough space etc.) Ping(ctx context.Context) error +} - // GCAware marks that this storage supports periodic - // peers collection - GCAware() bool - +// GarbageCollector marks that this storage supports periodic +// stale peers collection +type GarbageCollector interface { // ScheduleGC used to delete stale data, such as timed out seeders/leechers. // Note: implementation must create subroutine by itself ScheduleGC(gcInterval, peerLifeTime time.Duration) +} - // StatisticsAware marks that this storage supports periodic - // statistics collection - StatisticsAware() bool - +// StatisticsCollector marks that this storage supports periodic +// statistics collection +type StatisticsCollector interface { // ScheduleStatisticsCollection used to receive statistics information about hashes, // seeders and leechers count. // Note: implementation must create subroutine by itself ScheduleStatisticsCollection(reportInterval time.Duration) - - // Stopper is an interface that expects a Stop method to stop the PeerStorage. - // For more details see the documentation in the stop package. - stop.Stopper } // RegisterDriver makes a Driver available by the provided name. @@ -256,27 +253,27 @@ func NewStorage(cfg conf.NamedMapConfig) (ps PeerStorage, err error) { return } - if gc := ps.GCAware(); gc { + if gc, isOk := ps.(GarbageCollector); isOk { gcInterval, peerTTL := c.sanitizeGCConfig() logger.Info(). Str("name", cfg.Name). Dur("gcInterval", gcInterval). Dur("peerTTL", peerTTL). Msg("scheduling GC") - ps.ScheduleGC(gcInterval, peerTTL) + gc.ScheduleGC(gcInterval, peerTTL) } else { logger.Debug(). Str("name", cfg.Name). Msg("storage does not support GC") } - if st := ps.StatisticsAware(); st { + if st, isOk := ps.(StatisticsCollector); isOk { if statInterval := c.sanitizeStatisticsConfig(); statInterval > 0 { logger.Info(). Str("name", cfg.Name). Dur("statInterval", statInterval). Msg("scheduling statistics collection") - ps.ScheduleStatisticsCollection(statInterval) + st.ScheduleStatisticsCollection(statInterval) } else { logger.Info().Str("name", cfg.Name).Msg("statistics collection disabled because of zero reporting interval") } diff --git a/storage/test/storage_bench.go b/storage/test/storage_bench.go index a8023d0..fa2765a 100644 --- a/storage/test/storage_bench.go +++ b/storage/test/storage_bench.go @@ -3,6 +3,7 @@ package test import ( + "context" "math/rand" "net" "net/netip" @@ -100,9 +101,8 @@ func (bh *benchHolder) runBenchmark(b *testing.B, parallel bool, sf benchSetupFu } b.StopTimer() - errChan := ps.Stop() - for err := range errChan { - b.Fatal(err) + if err := ps.Close(); err != nil { + b.Fatal() } } @@ -125,7 +125,7 @@ func (bh *benchHolder) Nop(b *testing.B) { // Put can run in parallel. func (bh *benchHolder) Put(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -135,7 +135,7 @@ func (bh *benchHolder) Put(b *testing.B) { // Put1k can run in parallel. func (bh *benchHolder) Put1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -145,7 +145,7 @@ func (bh *benchHolder) Put1k(b *testing.B) { // Put1kInfoHash can run in parallel. func (bh *benchHolder) Put1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - return ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -155,7 +155,7 @@ func (bh *benchHolder) Put1kInfoHash(b *testing.B) { // Put1kInfoHash1k can run in parallel. func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -166,11 +166,11 @@ func (bh *benchHolder) Put1kInfoHash1k(b *testing.B) { // PutDelete can not run in parallel. func (bh *benchHolder) PutDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[0]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -180,11 +180,11 @@ func (bh *benchHolder) PutDelete(b *testing.B) { // PutDelete1k can not run in parallel. func (bh *benchHolder) PutDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -194,11 +194,11 @@ func (bh *benchHolder) PutDelete1k(b *testing.B) { // PutDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -208,11 +208,11 @@ func (bh *benchHolder) PutDelete1kInfoHash(b *testing.B) { // PutDelete1kInfoHash1k can not run in parallel. func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -223,7 +223,7 @@ func (bh *benchHolder) PutDelete1kInfoHash1k(b *testing.B) { // DeleteNonexist can run in parallel. func (bh *benchHolder) DeleteNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -234,7 +234,7 @@ func (bh *benchHolder) DeleteNonexist(b *testing.B) { // DeleteNonexist can run in parallel. func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -245,7 +245,7 @@ func (bh *benchHolder) DeleteNonexist1k(b *testing.B) { // DeleteNonexist1kInfoHash can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -256,7 +256,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash(b *testing.B) { // DeleteNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -267,7 +267,7 @@ func (bh *benchHolder) DeleteNonexist1kInfoHash1k(b *testing.B) { // GradNonexist can run in parallel. func (bh *benchHolder) GradNonexist(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) return nil }) } @@ -278,7 +278,7 @@ func (bh *benchHolder) GradNonexist(b *testing.B) { // GradNonexist1k can run in parallel. func (bh *benchHolder) GradNonexist1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) return nil }) } @@ -289,7 +289,7 @@ func (bh *benchHolder) GradNonexist1k(b *testing.B) { // GradNonexist1kInfoHash can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) return nil }) } @@ -301,7 +301,7 @@ func (bh *benchHolder) GradNonexist1kInfoHash(b *testing.B) { // GradNonexist1kInfoHash1k can run in parallel. func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, true, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - _ = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + _ = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return nil }) } @@ -313,15 +313,15 @@ func (bh *benchHolder) GradNonexist1kInfoHash1k(b *testing.B) { // PutGradDelete can not run in parallel. func (bh *benchHolder) PutGradDelete(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[0]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[0]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[0]) }) } @@ -330,15 +330,15 @@ func (bh *benchHolder) PutGradDelete(b *testing.B) { // PutGradDelete1k can not run in parallel. func (bh *benchHolder) PutGradDelete1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[0], bd.peers[i%peersCount]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[0], bd.peers[i%peersCount]) }) } @@ -348,15 +348,15 @@ func (bh *benchHolder) PutGradDelete1k(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) if err != nil { return err } - return ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[0]) + return ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[0]) }) } @@ -366,15 +366,15 @@ func (bh *benchHolder) PutGradDelete1kInfoHash(b *testing.B) { // PutGradDelete1kInfoHash can not run in parallel. func (bh *benchHolder) PutGradDelete1kInfoHash1k(b *testing.B) { bh.runBenchmark(b, false, nil, func(i int, ps storage.PeerStorage, bd *benchData) error { - err := ps.PutLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err := ps.PutLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.GraduateLeecher(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.GraduateLeecher(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) if err != nil { return err } - err = ps.DeleteSeeder(nil, bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) + err = ps.DeleteSeeder(context.TODO(), bd.infoHashes[i%ihCount], bd.peers[(i*3)%peersCount]) return err }) } @@ -385,9 +385,9 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { for i, peer := range bd.peers { var err error if i < l/2 { - err = ps.PutLeecher(nil, ih, peer) + err = ps.PutLeecher(context.TODO(), ih, peer) } else { - err = ps.PutSeeder(nil, ih, peer) + err = ps.PutSeeder(context.TODO(), ih, peer) } if err != nil { return err @@ -404,7 +404,7 @@ func putPeers(ps storage.PeerStorage, bd *benchData) error { // AnnounceLeecher can run in parallel. func (bh *benchHolder) AnnounceLeecher(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -415,7 +415,7 @@ func (bh *benchHolder) AnnounceLeecher(b *testing.B) { // AnnounceLeecher1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], false, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -426,7 +426,7 @@ func (bh *benchHolder) AnnounceLeecher1kInfoHash(b *testing.B) { // AnnounceSeeder can run in parallel. func (bh *benchHolder) AnnounceSeeder(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[0], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -437,7 +437,7 @@ func (bh *benchHolder) AnnounceSeeder(b *testing.B) { // AnnounceSeeder1kInfoHash can run in parallel. func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - _, err := ps.AnnouncePeers(nil, bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) + _, err := ps.AnnouncePeers(context.TODO(), bd.infoHashes[i%ihCount], true, 50, bd.peers[0].Addr().Is6()) return err }) } @@ -448,7 +448,7 @@ func (bh *benchHolder) AnnounceSeeder1kInfoHash(b *testing.B) { // ScrapeSwarm can run in parallel. func (bh *benchHolder) ScrapeSwarm(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(nil, bd.infoHashes[0]) + ps.ScrapeSwarm(context.TODO(), bd.infoHashes[0]) return nil }) } @@ -458,7 +458,7 @@ func (bh *benchHolder) ScrapeSwarm(b *testing.B) { // ScrapeSwarm1kInfoHash can run in parallel. func (bh *benchHolder) ScrapeSwarm1kInfoHash(b *testing.B) { bh.runBenchmark(b, true, putPeers, func(i int, ps storage.PeerStorage, bd *benchData) error { - ps.ScrapeSwarm(nil, bd.infoHashes[i%ihCount]) + ps.ScrapeSwarm(context.TODO(), bd.infoHashes[i%ihCount]) return nil }) } diff --git a/storage/test/storage_test_base.go b/storage/test/storage_test_base.go index a14b3f4..a360965 100644 --- a/storage/test/storage_test_base.go +++ b/storage/test/storage_test_base.go @@ -3,6 +3,7 @@ package test import ( + "context" "errors" "testing" @@ -40,7 +41,7 @@ type hashPeer struct { func (th *testHolder) DeleteSeeder(t *testing.T) { for _, c := range testData { - err := th.st.DeleteSeeder(nil, c.ih, c.peer) + err := th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -54,14 +55,14 @@ func (th *testHolder) PutLeecher(t *testing.T) { if c.peer.Addr().Is6() { peer = v6Peer } - err := th.st.PutLeecher(nil, c.ih, peer) + err := th.st.PutLeecher(context.TODO(), c.ih, peer) require.Nil(t, err) } } func (th *testHolder) DeleteLeecher(t *testing.T) { for _, c := range testData { - err := th.st.DeleteLeecher(nil, c.ih, c.peer) + err := th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -71,7 +72,7 @@ func (th *testHolder) DeleteLeecher(t *testing.T) { func (th *testHolder) AnnouncePeers(t *testing.T) { for _, c := range testData { - _, err := th.st.AnnouncePeers(nil, c.ih, false, 50, c.peer.Addr().Is6()) + _, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, c.peer.Addr().Is6()) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -81,7 +82,7 @@ func (th *testHolder) AnnouncePeers(t *testing.T) { func (th *testHolder) ScrapeSwarm(t *testing.T) { for _, c := range testData { - l, s, n := th.st.ScrapeSwarm(nil, c.ih) + l, s, n := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(0), s) require.Equal(t, uint32(0), l) require.Equal(t, uint32(0), n) @@ -91,26 +92,26 @@ func (th *testHolder) ScrapeSwarm(t *testing.T) { func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutLeecher(nil, c.ih, c.peer) + err := th.st.PutLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err := th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // non-seeder announce should still return the leecher - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(nil, c.ih) + l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(2), l) require.Equal(t, uint32(0), s) - err = th.st.DeleteLeecher(nil, c.ih, c.peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(nil, c.ih, true, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, true, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -122,22 +123,22 @@ func (th *testHolder) LeecherPutAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) SeederPutAnnounceDeleteAnnounce(t *testing.T) { for _, c := range testData { isV6 := c.peer.Addr().Is6() - err := th.st.PutSeeder(nil, c.ih, c.peer) + err := th.st.PutSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) // Should be leecher to see the seeder - peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) - l, s, _ := th.st.ScrapeSwarm(nil, c.ih) + l, s, _ := th.st.ScrapeSwarm(context.TODO(), c.ih) require.Equal(t, uint32(1), l) require.Equal(t, uint32(1), s) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -153,44 +154,44 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { if isV6 { peer = v6Peer } - err := th.st.PutLeecher(nil, c.ih, c.peer) + err := th.st.PutLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) - err = th.st.GraduateLeecher(nil, c.ih, c.peer) + err = th.st.GraduateLeecher(context.TODO(), c.ih, c.peer) require.Nil(t, err) // Has to be leecher to see the graduated seeder - peers, err := th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err := th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Deleting the Peer as a Leecher should have no effect - err = th.st.DeleteLeecher(nil, c.ih, c.peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) // Verify it's still there - peers, err = th.st.AnnouncePeers(nil, c.ih, false, 50, isV6) + peers, err = th.st.AnnouncePeers(context.TODO(), c.ih, false, 50, isV6) require.Nil(t, err) require.True(t, containsPeer(peers, c.peer)) // Clean up - err = th.st.DeleteLeecher(nil, c.ih, peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, peer) require.Nil(t, err) // Test ErrDNE for missing leecher - err = th.st.DeleteLeecher(nil, c.ih, peer) + err = th.st.DeleteLeecher(context.TODO(), c.ih, peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } require.Nil(t, err) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) require.Nil(t, err) - err = th.st.DeleteSeeder(nil, c.ih, c.peer) + err = th.st.DeleteSeeder(context.TODO(), c.ih, c.peer) if errors.Is(err, storage.ErrResourceDoesNotExist) { err = nil } @@ -200,35 +201,35 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { for _, c := range testData { - err := th.st.Put(nil, kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) + err := th.st.Put(context.TODO(), kvStoreCtx, storage.Entry{Key: c.peer.String(), Value: []byte(c.ih.RawString())}) require.Nil(t, err) // check if exist in ctx we put - contains, err := th.st.Contains(nil, kvStoreCtx, c.peer.String()) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) require.True(t, contains) // check if not exist in another ctx - contains, err = th.st.Contains(nil, "", c.peer.String()) + contains, err = th.st.Contains(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.False(t, contains) // check value and type in ctx we put - out, err := th.st.Load(nil, kvStoreCtx, c.peer.String()) + out, err := th.st.Load(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, c.ih, ih) // check value is nil in another ctx - dummy, err := th.st.Load(nil, "", c.peer.String()) + dummy, err := th.st.Load(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.Nil(t, dummy) - err = th.st.Delete(nil, kvStoreCtx, c.peer.String()) + err = th.st.Delete(context.TODO(), kvStoreCtx, c.peer.String()) require.Nil(t, err) - contains, err = th.st.Contains(nil, "", c.peer.String()) + contains, err = th.st.Contains(context.TODO(), "", c.peer.String()) require.Nil(t, err) require.False(t, contains) } @@ -245,29 +246,29 @@ func (th *testHolder) CustomBulkPutContainsLoadDelete(t *testing.T) { Value: []byte(c.ih.RawString()), }) } - err := th.st.Put(nil, kvStoreCtx, pairs...) + err := th.st.Put(context.TODO(), kvStoreCtx, pairs...) require.Nil(t, err) // check if exist in ctx we put for _, k := range keys { - contains, err := th.st.Contains(nil, kvStoreCtx, k) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k) require.Nil(t, err) require.True(t, contains) } // check value and type in ctx we put for _, p := range pairs { - out, _ := th.st.Load(nil, kvStoreCtx, p.Key) + out, _ := th.st.Load(context.TODO(), kvStoreCtx, p.Key) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, p.Value, []byte(ih.RawString())) } - err = th.st.Delete(nil, kvStoreCtx, keys...) + err = th.st.Delete(context.TODO(), kvStoreCtx, keys...) require.Nil(t, err) for _, k := range keys { - contains, err := th.st.Contains(nil, kvStoreCtx, k) + contains, err := th.st.Contains(context.TODO(), kvStoreCtx, k) require.Nil(t, err) require.False(t, contains) } @@ -304,8 +305,8 @@ func RunTests(t *testing.T, p storage.PeerStorage) { t.Run("CustomPutContainsLoadDelete", th.CustomPutContainsLoadDelete) t.Run("CustomBulkPutContainsLoadDelete", th.CustomBulkPutContainsLoadDelete) - e := th.st.Stop() - require.Nil(t, <-e) + e := th.st.Close() + require.Nil(t, e) } func containsPeer(peers []bittorrent.Peer, p bittorrent.Peer) bool { From be40392f93c0cb1f48ba8e223c1dc5be925c71a4 Mon Sep 17 00:00:00 2001 From: "Lawrence, Rendall" Date: Tue, 1 Nov 2022 18:26:25 +0300 Subject: [PATCH 3/3] fix invalid branch in RemapRouteParamsToBgContext --- bittorrent/params.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bittorrent/params.go b/bittorrent/params.go index 819bfe5..7b2b272 100644 --- a/bittorrent/params.go +++ b/bittorrent/params.go @@ -99,9 +99,8 @@ func InjectRouteParamsToContext(ctx context.Context, rp RouteParams) context.Con func RemapRouteParamsToBgContext(inCtx context.Context) context.Context { rp, isOk := inCtx.Value(RouteParamsKey).(RouteParams) if !isOk { - rp = RouteParams{} - } else { logger.Warn().Msg("unable to fetch route parameters, probably jammed context") + rp = RouteParams{} } return context.WithValue(context.Background(), RouteParamsKey, rp) }