add TLS configuration for redis/keydb connection

This commit is contained in:
Lawrence, Rendall
2025-01-03 22:00:45 +03:00
parent ebe4dc39b3
commit b0bbce093a
3 changed files with 75 additions and 4 deletions

View File

@@ -70,6 +70,13 @@ storage:
# only first address used
addresses: ["127.0.0.1:6379"]
# Use TLS for connecting to specified addresses.
tls: false
# Trusted CA certificates to check while TLS connection.
# If not provided, system certificates used.
ca_certs: []
# Database to be selected after connecting to the server.
db: 0

View File

@@ -37,6 +37,16 @@ mochi:
# The addresses of redis storage.
addresses: ["127.0.0.1:6379"]
# Use TLS for connecting to specified addresses.
# Note: server names must strictly match with names,
# provided in certificates (CN or SAN field) and
# issuer must be trusted by system or provided in ca_certs.
tls: false
# Trusted CA certificates to check while TLS connection.
# If not provided, system certificates used.
ca_certs: []
# Database number
db: 0

View File

@@ -24,11 +24,14 @@ package redis
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
@@ -119,6 +122,8 @@ func NewStore(cfg Config) (storage.PeerStorage, error) {
type Config struct {
PeerLifetime time.Duration `cfg:"peer_lifetime"`
Addresses []string
TLS bool
CACerts []string `cfg:"ca_certs"`
DB int
PoolSize int `cfg:"pool_size"`
Login string
@@ -187,12 +192,52 @@ func (cfg Config) Validate() (Config, error) {
Msg("falling back to default configuration")
}
if cfg.TLS {
for _, cert := range cfg.CACerts {
if _, err := os.Stat(cert); err != nil {
return cfg, err
}
}
}
return validCfg, nil
}
// Connect creates redis client from configuration
func (cfg Config) Connect() (con Connection, err error) {
var rs redis.UniversalClient
var tlsConf *tls.Config
if cfg.TLS {
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
if len(cfg.CACerts) > 0 {
certPool := x509.NewCertPool()
var ok bool
for _, cert := range cfg.CACerts {
var certData []byte
if certData, err = os.ReadFile(cert); err != nil {
logger.Warn().
Err(err).
Str("name", cert).
Msg("unable to read certificate file")
} else if len(certData) > 0 {
if added := certPool.AppendCertsFromPEM(certData); added {
ok = true
} else {
logger.Warn().
Str("name", cert).
Msg("unable to append certificate to trusted pool")
}
} else {
logger.Warn().
Str("name", cert).
Msg("empty certificate file")
}
}
if ok {
tlsConf.RootCAs = certPool
}
}
}
switch {
case cfg.Cluster:
rs = redis.NewClusterClient(&redis.ClusterOptions{
@@ -203,6 +248,7 @@ func (cfg Config) Connect() (con Connection, err error) {
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
PoolSize: cfg.PoolSize,
TLSConfig: tlsConf,
})
case cfg.Sentinel:
rs = redis.NewFailoverClient(&redis.FailoverOptions{
@@ -215,6 +261,7 @@ func (cfg Config) Connect() (con Connection, err error) {
WriteTimeout: cfg.WriteTimeout,
PoolSize: cfg.PoolSize,
DB: cfg.DB,
TLSConfig: tlsConf,
})
default:
rs = redis.NewClient(&redis.Options{
@@ -226,6 +273,7 @@ func (cfg Config) Connect() (con Connection, err error) {
WriteTimeout: cfg.WriteTimeout,
PoolSize: cfg.PoolSize,
DB: cfg.DB,
TLSConfig: tlsConf,
})
}
if err = rs.Ping(context.Background()).Err(); err == nil && !errors.Is(err, redis.Nil) {
@@ -464,7 +512,8 @@ func (ps *store) GraduateLeecher(ctx context.Context, ih bittorrent.InfoHash, pe
// peerMinimumLen is the least allowed length of string serialized Peer
const peerMinimumLen = bittorrent.PeerIDLen + 2 + net.IPv4len
var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (PeerID + Port + IPv4))", peerMinimumLen)
var errInvalidPeerDataSize = fmt.Errorf("invalid peer data (must be at least %d bytes (PeerID + Port + IPv4))",
peerMinimumLen)
// UnpackPeer constructs Peer from serialized by Peer.PackPeer data: PeerID[20by]Port[2by]net.IP[4/16by]
func UnpackPeer(data string) (peer bittorrent.Peer, err error) {
@@ -510,7 +559,9 @@ type getPeersFn func(context.Context, string, int) *redis.StringSliceCmd
// converts result to bittorrent.Peer array.
// If forSeeder set to true - returns only leechers, if false -
// seeders and if maxCount not reached - leechers.
func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn) (out []bittorrent.Peer, err error) {
func (ps *Connection) GetPeers(
ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, maxCount int, isV6 bool, membersFn getPeersFn,
) (out []bittorrent.Peer, err error) {
infoHash := ih.RawString()
infoHashKeys := make([]string, 1, 2)
@@ -544,7 +595,9 @@ func (ps *Connection) GetPeers(ctx context.Context, ih bittorrent.InfoHash, forS
return
}
func (ps *store) AnnouncePeers(ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool) ([]bittorrent.Peer, error) {
func (ps *store) AnnouncePeers(
ctx context.Context, ih bittorrent.InfoHash, forSeeder bool, numWant int, v6 bool,
) ([]bittorrent.Peer, error) {
logger.Trace().
Stringer("infoHash", ih).
Bool("forSeeder", forSeeder).
@@ -720,7 +773,8 @@ func (ps *store) gc(cutoff time.Time) {
for _, infoHashKey := range infoHashKeys {
var cntKey string
var seeder bool
if seeder = strings.HasPrefix(infoHashKey, IH4SeederKey) || strings.HasPrefix(infoHashKey, IH6SeederKey); seeder {
if seeder = strings.HasPrefix(infoHashKey, IH4SeederKey) || strings.HasPrefix(infoHashKey,
IH6SeederKey); seeder {
cntKey = CountSeederKey
} else if strings.HasPrefix(infoHashKey, IH4LeecherKey) || strings.HasPrefix(infoHashKey, IH6LeecherKey) {
cntKey = CountLeecherKey