diff --git a/dist/example_config_redis.yaml b/dist/example_config_redis.yaml index 1b5b497..f139231 100644 --- a/dist/example_config_redis.yaml +++ b/dist/example_config_redis.yaml @@ -70,6 +70,13 @@ storage: # only first address used addresses: ["127.0.0.1:6379"] + # Use TLS for connecting to specified addresses. + tls: false + + # Trusted CA certificates to check while TLS connection. + # If not provided, system certificates used. + ca_certs: [] + # Database to be selected after connecting to the server. db: 0 diff --git a/docs/storage/redis.md b/docs/storage/redis.md index 908a097..5d811d3 100644 --- a/docs/storage/redis.md +++ b/docs/storage/redis.md @@ -37,6 +37,16 @@ mochi: # The addresses of redis storage. addresses: ["127.0.0.1:6379"] + + # Use TLS for connecting to specified addresses. + # Note: server names must strictly match with names, + # provided in certificates (CN or SAN field) and + # issuer must be trusted by system or provided in ca_certs. + tls: false + + # Trusted CA certificates to check while TLS connection. + # If not provided, system certificates used. + ca_certs: [] # Database number db: 0 diff --git a/storage/redis/storage.go b/storage/redis/storage.go index 3f6454f..5aba7c9 100644 --- a/storage/redis/storage.go +++ b/storage/redis/storage.go @@ -24,11 +24,14 @@ package redis import ( "context" + "crypto/tls" + "crypto/x509" "encoding/binary" "errors" "fmt" "net" "net/netip" + "os" "strconv" "strings" "sync" @@ -119,6 +122,8 @@ func NewStore(cfg Config) (storage.PeerStorage, error) { type Config struct { PeerLifetime time.Duration `cfg:"peer_lifetime"` Addresses []string + TLS bool + CACerts []string `cfg:"ca_certs"` DB int PoolSize int `cfg:"pool_size"` Login string @@ -187,12 +192,52 @@ func (cfg Config) Validate() (Config, error) { Msg("falling back to default configuration") } + if cfg.TLS { + for _, cert := range cfg.CACerts { + if _, err := os.Stat(cert); err != nil { + return cfg, err + } + } + } + return validCfg, nil } // Connect creates redis client from configuration func (cfg Config) Connect() (con Connection, err error) { var rs redis.UniversalClient + var tlsConf *tls.Config + if cfg.TLS { + tlsConf = &tls.Config{MinVersion: tls.VersionTLS12} + if len(cfg.CACerts) > 0 { + certPool := x509.NewCertPool() + var ok bool + for _, cert := range cfg.CACerts { + var certData []byte + if certData, err = os.ReadFile(cert); err != nil { + logger.Warn(). + Err(err). + Str("name", cert). + Msg("unable to read certificate file") + } else if len(certData) > 0 { + if added := certPool.AppendCertsFromPEM(certData); added { + ok = true + } else { + logger.Warn(). + Str("name", cert). + Msg("unable to append certificate to trusted pool") + } + } else { + logger.Warn(). + Str("name", cert). + Msg("empty certificate file") + } + } + if ok { + tlsConf.RootCAs = certPool + } + } + } switch { case cfg.Cluster: rs = redis.NewClusterClient(&redis.ClusterOptions{ @@ -203,6 +248,7 @@ func (cfg Config) Connect() (con Connection, err error) { ReadTimeout: cfg.ReadTimeout, WriteTimeout: cfg.WriteTimeout, PoolSize: cfg.PoolSize, + TLSConfig: tlsConf, }) case cfg.Sentinel: rs = redis.NewFailoverClient(&redis.FailoverOptions{ @@ -215,6 +261,7 @@ func (cfg Config) Connect() (con Connection, err error) { WriteTimeout: cfg.WriteTimeout, PoolSize: cfg.PoolSize, DB: cfg.DB, + TLSConfig: tlsConf, }) default: rs = redis.NewClient(&redis.Options{ @@ -226,6 +273,7 @@ func (cfg Config) Connect() (con Connection, err error) { WriteTimeout: cfg.WriteTimeout, PoolSize: cfg.PoolSize, DB: cfg.DB, + TLSConfig: tlsConf, }) } if err = rs.Ping(context.Background()).Err(); err == nil && !errors.Is(err, redis.Nil) { @@ -464,7 +512,8 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe // peerMinimumLen is the least allowed length of string serialized Peer const peerMinimumLen = bittorrent.PeerIDLen + 2 + net.IPv4len -var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (PeerID + Port + IPv4))", peerMinimumLen) +var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (PeerID + Port + IPv4))", + peerMinimumLen) // UnpackPeer constructs Peer from serialized by Peer.PackPeer data: PeerID[20by]Port[2by]net.IP[4/16by] func UnpackPeer(data string) (peer bittorrent.Peer, err error) { @@ -510,7 +559,9 @@ type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd // converts result to bittorrent.Peer array. // If forSeeder set to true - returns only leechers, if false - // seeders and if maxCount not reached - leechers. -func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) { +func (ps *Connection) GetPeers( + ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn, +) (out []bittorrent.Peer, err error) { infoHash := ih.RawString() infoHashKeys := make([]string, 1, 2) @@ -544,7 +595,9 @@ func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forS return } -func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) { +func (ps *store) AnnouncePeers( + ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool, +) ([]bittorrent.Peer, error) { logger.Trace(). Stringer("infoHash", ih). Bool("forSeeder", forSeeder). @@ -720,7 +773,8 @@ func (ps *store) gc(cutoff time.Time) { for _, infoHashKey := range infoHashKeys { var cntKey string var seeder bool - if seeder = strings.HasPrefix(infoHashKey, IH4SeederKey) || strings.HasPrefix(infoHashKey, IH6SeederKey); seeder { + if seeder = strings.HasPrefix(infoHashKey, IH4SeederKey) || strings.HasPrefix(infoHashKey, + IH6SeederKey); seeder { cntKey = CountSeederKey } else if strings.HasPrefix(infoHashKey, IH4LeecherKey) || strings.HasPrefix(infoHashKey, IH6LeecherKey) { cntKey = CountLeecherKey