diff --git a/middleware/hooks.go b/middleware/hooks.go index e81dd70..2b1d77c 100644 --- a/middleware/hooks.go +++ b/middleware/hooks.go @@ -180,7 +180,7 @@ func (h *responseHook) appendPeers(req *bittorrent.AnnounceRequest, resp *bittor } l := len(peers) - uniquePeers := make(map[bittorrent.Peer]interface{}, l) + uniquePeers := make(map[bittorrent.Peer]any, l) resp.IPv4Peers = make([]bittorrent.Peer, 0, l/2) resp.IPv6Peers = make([]bittorrent.Peer, 0, l/2) diff --git a/pkg/log/log.go b/pkg/log/log.go index 83e96dd..798798c 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -181,14 +181,14 @@ func (l *Logger) Log() *zerolog.Event { // Print sends a log event using debug level and no extra field. // Arguments are handled in the manner of fmt.Print. -func (l *Logger) Print(v ...interface{}) { +func (l *Logger) Print(v ...any) { l.init() l.Logger.Print(v...) } // Printf sends a log event using debug level and no extra field. // Arguments are handled in the manner of fmt.Printf. -func (l *Logger) Printf(format string, v ...interface{}) { +func (l *Logger) Printf(format string, v ...any) { l.init() l.Logger.Printf(format, v...) } @@ -276,13 +276,13 @@ func Log() *zerolog.Event { // Print sends a log event using debug level and no extra field. // Arguments are handled in the manner of fmt.Print. -func Print(v ...interface{}) { +func Print(v ...any) { root.Print(v...) } // Printf sends a log event using debug level and no extra field. // Arguments are handled in the manner of fmt.Printf. -func Printf(format string, v ...interface{}) { +func Printf(format string, v ...any) { root.Printf(format, v...) } diff --git a/storage/pg/storage.go b/storage/pg/storage.go index 8208dd5..b3583bf 100644 --- a/storage/pg/storage.go +++ b/storage/pg/storage.go @@ -29,14 +29,16 @@ const ( Name = "pg" defaultPingQuery = "SELECT 0" + + errRequiredParameterNotSetMsg = "required parameter not provided: %s" + errRequiredColumnsNotFoundMsg = "one or more required columns not found in result set: %v" + errRollBackMsg = "error occurred while rolling back failed query: %v, failed query error: %v" ) var ( logger = log.NewLogger(Name) errConnectionStringNotProvided = errors.New("database address not provided") - errRequiredParameterNotSetMsg = "required parameter not provided: %s" - errRequiredColumnsNotFoundMsg = "one or more required columns not found in result set: %v" tc = timecache.New() ) @@ -53,6 +55,10 @@ func builder(icfg conf.MapConfig) (storage.PeerStorage, error) { return nil, err } + return newStore(cfg) +} + +func newStore(cfg Config) (storage.PeerStorage, error) { cfg, err := cfg.Validate() if err != nil { return nil, err @@ -72,32 +78,36 @@ func builder(icfg conf.MapConfig) (storage.PeerStorage, error) { return &store{Config: cfg, Pool: con, wg: sync.WaitGroup{}, closed: make(chan any)}, nil } +type peerQueryConf struct { + AddQuery string `cfg:"add_query"` + DelQuery string `cfg:"del_query"` + GraduateQuery string `cfg:"graduate_query"` + CountQuery string `cfg:"count_query"` + CountSeedersColumn string `cfg:"count_seeders_column"` + CountLeechersColumn string `cfg:"count_leechers_column"` + ByInfoHashClause string `cfg:"by_info_hash_clause"` +} + +type announceQueryConf struct { + Query string + PeerIDColumn string `cfg:"peer_id_column"` + AddressColumn string `cfg:"address_column"` + PortColumn string `cfg:"port_column"` +} + +type dataQueryConf struct { + AddQuery string `cfg:"add_query"` + GetQuery string `cfg:"get_query"` + DelQuery string `cfg:"del_query"` +} + // Config holds the configuration of a redis PeerStorage. type Config struct { - ConnectionString string `cfg:"connection_string"` - PingQuery string `cfg:"ping_query"` - Peer struct { - AddQuery string `cfg:"add_query"` - DelQuery string `cfg:"del_query"` - GraduateQuery string `cfg:"graduate_query"` - // SELECT COUNT(1) FILTER (WHERE seeder) AS seeders, COUNT(1) FILTER (WHERE NOT seeder) AS leechers FROM peers - CountQuery string `cfg:"count_query"` - CountSeedersColumn string `cfg:"count_seeders_column"` - CountLeechersColumn string `cfg:"count_leechers_column"` - // WHERE ih = ? - ByInfoHashClause string `cfg:"by_info_hash_clause"` - } - Announce struct { - Query string - PeerIDColumn string `cfg:"peer_id_column"` - AddressColumn string `cfg:"address_column"` - PortColumn string `cfg:"port_column"` - } - Data struct { - AddQuery string `cfg:"add_query"` - GetQuery string `cfg:"get_query"` - DelQuery string `cfg:"del_query"` - } + ConnectionString string `cfg:"connection_string"` + PingQuery string `cfg:"ping_query"` + Peer peerQueryConf + Announce announceQueryConf + Data dataQueryConf GCQuery string `cfg:"gc_query"` InfoHashCountQuery string `cfg:"info_hash_count_query"` } @@ -208,8 +218,8 @@ func (cfg Config) Validate() (Config, error) { } validCfg.Announce.PeerIDColumn = strings.ToUpper(validCfg.Announce.PeerIDColumn) - validCfg.Announce.PeerIDColumn = strings.ToUpper(validCfg.Announce.AddressColumn) - validCfg.Announce.PeerIDColumn = strings.ToUpper(validCfg.Announce.PortColumn) + validCfg.Announce.AddressColumn = strings.ToUpper(validCfg.Announce.AddressColumn) + validCfg.Announce.PortColumn = strings.ToUpper(validCfg.Announce.PortColumn) validCfg.Peer.CountSeedersColumn = strings.ToUpper(validCfg.Peer.CountSeedersColumn) validCfg.Peer.CountLeechersColumn = strings.ToUpper(validCfg.Peer.CountLeechersColumn) @@ -224,24 +234,74 @@ type store struct { closed chan any } -func (s *store) Put(ctx string, values ...storage.Entry) error { - // TODO implement me - panic("implement me") +func (s *store) Put(ctx string, values ...storage.Entry) (err error) { + var tx pgx.Tx + if tx, err = s.Begin(context.TODO()); err == nil { + for _, v := range values { + val := v.Value + switch tOut := val.(type) { + case string: + val = []byte(tOut) + } + if _, err = tx.Exec(context.TODO(), s.Data.AddQuery, ctx, v.Key, val); err != nil { + break + } + } + if err == nil { + err = tx.Commit(context.TODO()) + } else { + if txErr := tx.Rollback(context.TODO()); txErr != nil { + err = fmt.Errorf(errRollBackMsg, txErr, err) + } + } + } + return } -func (s *store) Contains(ctx string, key string) (bool, error) { - // TODO implement me - panic("implement me") +func (s *store) Contains(ctx string, key string) (contains bool, err error) { + var rows pgx.Rows + if rows, err = s.Query(context.TODO(), s.Data.GetQuery, ctx, key); err == nil { + defer rows.Close() + contains = rows.Next() + } + return } -func (s *store) Load(ctx string, key string) (any, error) { - // TODO implement me - panic("implement me") +func (s *store) Load(ctx string, key string) (out any, err error) { + var rows pgx.Rows + if rows, err = s.Query(context.TODO(), s.Data.GetQuery, ctx, key); err == nil { + defer rows.Close() + if rows.Next() { + var values []any + if values, err = rows.Values(); err == nil && len(values) > 0 { + out = values[0] + switch tOut := out.(type) { + case []byte: + out = string(tOut) + } + } + } + } + return } -func (s *store) Delete(ctx string, keys ...string) error { - // TODO implement me - panic("implement me") +func (s *store) Delete(ctx string, keys ...string) (err error) { + var tx pgx.Tx + if tx, err = s.Begin(context.TODO()); err == nil { + for _, k := range keys { + if _, err = tx.Exec(context.TODO(), s.Data.DelQuery, ctx, k); err != nil { + break + } + } + if err == nil { + err = tx.Commit(context.TODO()) + } else { + if txErr := tx.Rollback(context.TODO()); txErr != nil { + err = fmt.Errorf(errRollBackMsg, txErr, err) + } + } + } + return } func (s *store) Preservable() bool { @@ -318,7 +378,7 @@ func (s *store) putPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo Object("peer", peer). Bool("seeder", seeder). Msg("put peer") - args := []interface{}{[]byte(ih), peer.ID[:], net.IP(peer.Addr().AsSlice()), peer.Port(), seeder, peer.Addr().Is6()} + args := []any{[]byte(ih), peer.ID[:], net.IP(peer.Addr().AsSlice()), peer.Port(), seeder, peer.Addr().Is6()} if s.GCAware() { args = append(args, tc.Now()) } @@ -330,7 +390,6 @@ func (s *store) delPeer(ih bittorrent.InfoHash, peer bittorrent.Peer, seeder boo logger.Trace(). Stringer("infoHash", ih). Object("peer", peer). - Bool("seeder", seeder). Msg("del peer") _, err := s.Exec(context.TODO(), s.Peer.DelQuery, []byte(ih), peer.ID[:], net.IP(peer.Addr().AsSlice()), peer.Port(), seeder) return err @@ -357,13 +416,13 @@ func (s *store) GraduateLeecher(ih bittorrent.InfoHash, peer bittorrent.Peer) er Stringer("infoHash", ih). Object("peer", peer). Msg("graduate leecher") - _, err := s.Exec(context.TODO(), s.Peer.GraduateQuery, []byte(ih), peer.ID[:], peer.Addr(), peer.Port()) + _, err := s.Exec(context.TODO(), s.Peer.GraduateQuery, []byte(ih), peer.ID[:], net.IP(peer.Addr().AsSlice()), peer.Port()) return err } func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV6 bool) (peers []bittorrent.Peer, err error) { var rows pgx.Rows - if rows, err = s.Query(context.TODO(), s.Announce.Query, []byte(ih), isV6, seeders, maxCount); err == nil { + if rows, err = s.Query(context.TODO(), s.Announce.Query, []byte(ih), seeders, isV6, maxCount); err == nil { defer rows.Close() idIndex, ipIndex, portIndex := -1, -1, -1 for i, field := range rows.FieldDescriptions() { @@ -399,7 +458,7 @@ func (s *store) getPeers(ih bittorrent.InfoHash, seeders bool, maxCount int, isV var id []byte var ip net.IP var port int - into := make([]interface{}, maxIndex+1) + into := make([]any, maxIndex+1) into[idIndex], into[ipIndex], into[portIndex] = &id, &ip, &port if err = rows.Scan(into...); err == nil { @@ -455,7 +514,7 @@ func (s *store) AnnouncePeers(ih bittorrent.InfoHash, forSeeder bool, numWant in return } -func (s *store) countPeers(ih bittorrent.InfoHash) (leechers int, seeders int) { +func (s *store) countPeers(ih bittorrent.InfoHash) (seeders int, leechers int) { var rows pgx.Rows var err error if ih == bittorrent.NoneInfoHash { @@ -485,7 +544,7 @@ func (s *store) countPeers(ih bittorrent.InfoHash) (leechers int, seeders int) { } else { mi = li } - into := make([]interface{}, mi+1) + into := make([]any, mi+1) into[si], into[li] = &seeders, &leechers err = rows.Scan(into...) @@ -514,7 +573,18 @@ func (s *store) Ping() error { func (s *store) Stop() stop.Result { c := make(stop.Channel) - s.Close() + go func() { + if s.closed != nil { + close(s.closed) + } + s.wg.Wait() + if s.Pool != nil { + logger.Info().Msg("pg exiting. mochi does not clear data in database when exiting.") + s.Close() + s.Pool = nil + } + c.Done() + }() return c.Result() } diff --git a/storage/pg/storage_test.go b/storage/pg/storage_test.go new file mode 100644 index 0000000..b298b8f --- /dev/null +++ b/storage/pg/storage_test.go @@ -0,0 +1,83 @@ +package pg + +import ( + "context" + "fmt" + "testing" + + s "github.com/sot-tech/mochi/storage" + "github.com/sot-tech/mochi/storage/test" +) + +const ( + createTablesQuery = ` +DROP TABLE IF EXISTS mo_peers; +CREATE TABLE mo_peers ( + info_hash bytea NOT NULL, + peer_id bytea NOT NULL, + address inet NOT NULL, + port int2 NOT NULL, + is_seeder bool NOT NULL, + is_v6 bool NOT NULL, + created timestamp NOT NULL DEFAULT current_timestamp, + UNIQUE(info_hash, peer_id, address, port) +); + +CREATE INDEX peers_ih_idx ON mo_peers(info_hash); +CREATE INDEX peers_created_idx ON mo_peers(created); +CREATE INDEX peers_announce_idx ON mo_peers(info_hash, is_seeder, is_v6); + +DROP TABLE IF EXISTS mo_kv; +CREATE TABLE mo_kv ( + context varchar NOT NULL, + name varchar NOT NULL, + value bytea, + UNIQUE (context, name) +); +` +) + +var cfg = Config{ + ConnectionString: "host=127.0.0.1 database=test user=postgres", + PingQuery: "SELECT 1", + Peer: peerQueryConf{ + AddQuery: "INSERT INTO mo_peers VALUES($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (info_hash, peer_id, address, port) DO UPDATE SET created = EXCLUDED.created, is_seeder = EXCLUDED.is_seeder", + DelQuery: "DELETE FROM mo_peers WHERE info_hash=$1 AND peer_id=$2 AND address=$3 AND port=$4 AND is_seeder=$5", + GraduateQuery: "UPDATE mo_peers SET is_seeder=TRUE WHERE info_hash=$1 AND peer_id=$2 AND address=$3 AND port=$4 AND NOT is_seeder", + CountQuery: "SELECT COUNT(1) FILTER (WHERE is_seeder) AS seeders, COUNT(1) FILTER (WHERE NOT is_seeder) AS leechers FROM mo_peers", + CountSeedersColumn: "seeders", + CountLeechersColumn: "leechers", + ByInfoHashClause: "WHERE info_hash = $1", + }, + Announce: announceQueryConf{ + Query: "SELECT peer_id, address, port FROM mo_peers WHERE info_hash=$1 AND is_seeder=$2 AND is_v6=$3 LIMIT $4", + PeerIDColumn: "peer_id", + AddressColumn: "address", + PortColumn: "port", + }, + Data: dataQueryConf{ + AddQuery: "INSERT INTO mo_kv VALUES($1, $2, ($3)::bytea) ON CONFLICT (context, name) DO NOTHING", + GetQuery: "SELECT value FROM mo_kv WHERE context=$1 AND name=$2", + DelQuery: "DELETE FROM mo_kv WHERE context=$1 AND name=$2", + }, + GCQuery: "DELETE FROM mo_peers WHERE created > $1", + InfoHashCountQuery: "SELECT COUNT(DISTINCT info_hash) as info_hashes FROM mo_peers", +} + +func createNew() s.PeerStorage { + var ps s.PeerStorage + var err error + ps, err = newStore(cfg) + if err != nil { + panic(fmt.Sprint("Unable to create PostgreSQL connection: ", err, "\nThis driver needs real PostgreSQL instance")) + } + pss := ps.(*store) + if _, err = pss.Exec(context.Background(), createTablesQuery); err != nil { + panic(fmt.Sprint("Unable to create test PostgreSQL tables: ", err)) + } + return ps +} + +func TestStorage(t *testing.T) { test.RunTests(t, createNew()) } + +func BenchmarkStorage(b *testing.B) { test.RunBenchmarks(b, createNew) } diff --git a/storage/redis/storage_test.go b/storage/redis/storage_test.go index 3e305d4..a8ea916 100644 --- a/storage/redis/storage_test.go +++ b/storage/redis/storage_test.go @@ -11,7 +11,6 @@ import ( var cfg = Config{ Addresses: []string{"localhost:6379"}, - PeerLifetime: 30 * time.Minute, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, ConnectTimeout: 10 * time.Second, @@ -22,7 +21,7 @@ func createNew() s.PeerStorage { var err error ps, err = newStore(cfg) if err != nil { - panic(fmt.Sprint("Unable to create KeyDB connection: ", err, "\nThis driver needs real Redis instance")) + panic(fmt.Sprint("Unable to create Redis connection: ", err, "\nThis driver needs real Redis instance")) } return ps } diff --git a/storage/test/storage_test_base.go b/storage/test/storage_test_base.go index ae1495e..832b21a 100644 --- a/storage/test/storage_test_base.go +++ b/storage/test/storage_test_base.go @@ -13,6 +13,8 @@ import ( "github.com/sot-tech/mochi/storage" ) +const kv_store_ctx = "test" + func init() { _ = log.ConfigureLogger("", "warn", false, false) } @@ -39,7 +41,10 @@ type hashPeer struct { func (th *testHolder) DeleteSeeder(t *testing.T) { for _, c := range testData { err := th.st.DeleteSeeder(c.ih, c.peer) - require.Equal(t, storage.ErrResourceDoesNotExist, err) + if errors.Is(err, storage.ErrResourceDoesNotExist) { + err = nil + } + require.Nil(t, err) } } @@ -57,7 +62,10 @@ func (th *testHolder) PutLeecher(t *testing.T) { func (th *testHolder) DeleteLeecher(t *testing.T) { for _, c := range testData { err := th.st.DeleteLeecher(c.ih, c.peer) - require.Equal(t, storage.ErrResourceDoesNotExist, err) + if errors.Is(err, storage.ErrResourceDoesNotExist) { + err = nil + } + require.Nil(t, err) } } @@ -158,7 +166,10 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { // Deleting the Peer as a Leecher should have no effect err = th.st.DeleteLeecher(c.ih, c.peer) - require.Equal(t, storage.ErrResourceDoesNotExist, err) + if errors.Is(err, storage.ErrResourceDoesNotExist) { + err = nil + } + require.Nil(t, err) // Verify it's still there peers, err = th.st.AnnouncePeers(c.ih, false, 50, isV6) @@ -171,23 +182,29 @@ func (th *testHolder) LeecherPutGraduateAnnounceDeleteAnnounce(t *testing.T) { // Test ErrDNE for missing leecher err = th.st.DeleteLeecher(c.ih, peer) - require.Equal(t, storage.ErrResourceDoesNotExist, err) + if errors.Is(err, storage.ErrResourceDoesNotExist) { + err = nil + } + require.Nil(t, err) err = th.st.DeleteSeeder(c.ih, c.peer) require.Nil(t, err) err = th.st.DeleteSeeder(c.ih, c.peer) - require.Equal(t, storage.ErrResourceDoesNotExist, err) + if errors.Is(err, storage.ErrResourceDoesNotExist) { + err = nil + } + require.Nil(t, err) } } func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { for _, c := range testData { - err := th.st.Put("test", storage.Entry{Key: c.peer.String(), Value: c.ih.RawString()}) + err := th.st.Put(kv_store_ctx, storage.Entry{Key: c.peer.String(), Value: c.ih.RawString()}) require.Nil(t, err) // check if exist in ctx we put - contains, err := th.st.Contains("test", c.peer.String()) + contains, err := th.st.Contains(kv_store_ctx, c.peer.String()) require.Nil(t, err) require.True(t, contains) @@ -197,7 +214,7 @@ func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { require.False(t, contains) // check value and type in ctx we put - out, err := th.st.Load("test", c.peer.String()) + out, err := th.st.Load(kv_store_ctx, c.peer.String()) require.Nil(t, err) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) @@ -208,10 +225,10 @@ func (th *testHolder) CustomPutContainsLoadDelete(t *testing.T) { require.Nil(t, err) require.Nil(t, dummy) - err = th.st.Delete("test", c.peer.String()) + err = th.st.Delete(kv_store_ctx, c.peer.String()) require.Nil(t, err) - contains, err = th.st.Contains("peers", c.peer.String()) + contains, err = th.st.Contains("", c.peer.String()) require.Nil(t, err) require.False(t, contains) } @@ -228,29 +245,29 @@ func (th *testHolder) CustomBulkPutContainsLoadDelete(t *testing.T) { Value: c.ih.RawString(), }) } - err := th.st.Put("test", pairs...) + err := th.st.Put(kv_store_ctx, pairs...) require.Nil(t, err) // check if exist in ctx we put for _, k := range keys { - contains, err := th.st.Contains("test", k) + contains, err := th.st.Contains(kv_store_ctx, k) require.Nil(t, err) require.True(t, contains) } // check value and type in ctx we put for _, p := range pairs { - out, _ := th.st.Load("test", p.Key) + out, _ := th.st.Load(kv_store_ctx, p.Key) ih, err := bittorrent.NewInfoHash(out) require.Nil(t, err) require.Equal(t, p.Value, ih.RawString()) } - err = th.st.Delete("test", keys...) + err = th.st.Delete(kv_store_ctx, keys...) require.Nil(t, err) for _, k := range keys { - contains, err := th.st.Contains("test", k) + contains, err := th.st.Contains(kv_store_ctx, k) require.Nil(t, err) require.False(t, contains) }