mirror of
https://github.com/sot-tech/mochi.git
synced 2026-04-30 01:19:59 -07:00
udp: Add constructor for ConnectionIDGenerator, return buffer to pool on socket timeout
This commit is contained in:
@@ -18,20 +18,24 @@ type ConnectionIDGenerator struct {
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// Init generates the AES key and sets up the first initialization vector.
|
||||
func (g *ConnectionIDGenerator) Init() error {
|
||||
// NewConnectionIDGenerator creates a ConnectionIDGenerator and generates its
|
||||
// AES key and first initialization vector.
|
||||
func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
|
||||
gen = &ConnectionIDGenerator{}
|
||||
key := make([]byte, 16)
|
||||
_, err := rand.Read(key)
|
||||
|
||||
_, err = rand.Read(key)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
g.block, err = aes.NewCipher(key)
|
||||
gen.block, err = aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
return g.NewIV()
|
||||
err = gen.NewIV()
|
||||
return
|
||||
}
|
||||
|
||||
// Generate returns the 64-bit connection ID for an IP
|
||||
|
||||
@@ -11,15 +11,13 @@ import (
|
||||
)
|
||||
|
||||
func TestInitReturnsNoError(t *testing.T) {
|
||||
gen := &ConnectionIDGenerator{}
|
||||
if err := gen.Init(); err != nil {
|
||||
if _, err := NewConnectionIDGenerator(); err != nil {
|
||||
t.Error("Init returned", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testGenerateConnectionID(t *testing.T, ip net.IP) {
|
||||
gen := &ConnectionIDGenerator{}
|
||||
gen.Init()
|
||||
gen, _ := NewConnectionIDGenerator()
|
||||
|
||||
id1 := gen.Generate(ip)
|
||||
id2 := gen.Generate(ip)
|
||||
@@ -46,8 +44,7 @@ func TestGenerateConnectionIDIPv6(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMatchesWorksWithPreviousIV(t *testing.T) {
|
||||
gen := &ConnectionIDGenerator{}
|
||||
gen.Init()
|
||||
gen, _ := NewConnectionIDGenerator()
|
||||
ip := net.ParseIP("192.168.1.123").To4()
|
||||
|
||||
id1 := gen.Generate(ip)
|
||||
|
||||
@@ -62,6 +62,7 @@ func (s *Server) serve(listenAddr string) error {
|
||||
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
|
||||
pool.GiveSlice(buffer)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
@@ -113,8 +114,8 @@ func (s *Server) Stop() {
|
||||
|
||||
// NewServer returns a new UDP server for a given configuration and tracker.
|
||||
func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server {
|
||||
gen := &ConnectionIDGenerator{}
|
||||
if err := gen.Init(); err != nil {
|
||||
gen, err := NewConnectionIDGenerator()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user