mirror of
				https://github.com/yggdrasil-network/yggdrasil-go.git
				synced 2025-11-04 11:15:07 +03:00 
			
		
		
		
	Protect session nonces with mutexes, modify sent/received bytes atomically
This commit is contained in:
		
							parent
							
								
									ade684beff
								
							
						
					
					
						commit
						e3eadba4b7
					
				
					 2 changed files with 25 additions and 10 deletions
				
			
		| 
						 | 
					@ -3,6 +3,7 @@ package yggdrasil
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/hex"
 | 
						"encoding/hex"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
 | 
						"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
 | 
				
			||||||
| 
						 | 
					@ -84,7 +85,7 @@ func (c *Conn) Read(b []byte) (int, error) {
 | 
				
			||||||
		b = append(b, bs...)
 | 
							b = append(b, bs...)
 | 
				
			||||||
		c.session.updateNonce(&p.Nonce)
 | 
							c.session.updateNonce(&p.Nonce)
 | 
				
			||||||
		c.session.time = time.Now()
 | 
							c.session.time = time.Now()
 | 
				
			||||||
		c.session.bytesRecvd += uint64(len(bs))
 | 
							atomic.AddUint64(&c.session.bytesRecvd, uint64(len(b)))
 | 
				
			||||||
		return len(b), nil
 | 
							return len(b), nil
 | 
				
			||||||
	case <-c.session.closed:
 | 
						case <-c.session.closed:
 | 
				
			||||||
		return len(b), errors.New("session was closed")
 | 
							return len(b), errors.New("session was closed")
 | 
				
			||||||
| 
						 | 
					@ -106,7 +107,9 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
 | 
				
			||||||
	// code isn't multithreaded so appending to this is safe
 | 
						// code isn't multithreaded so appending to this is safe
 | 
				
			||||||
	coords := c.session.coords
 | 
						coords := c.session.coords
 | 
				
			||||||
	// Prepare the payload
 | 
						// Prepare the payload
 | 
				
			||||||
 | 
						c.session.myNonceMutex.Lock()
 | 
				
			||||||
	payload, nonce := crypto.BoxSeal(&c.session.sharedSesKey, b, &c.session.myNonce)
 | 
						payload, nonce := crypto.BoxSeal(&c.session.sharedSesKey, b, &c.session.myNonce)
 | 
				
			||||||
 | 
						c.session.myNonceMutex.Unlock()
 | 
				
			||||||
	defer util.PutBytes(payload)
 | 
						defer util.PutBytes(payload)
 | 
				
			||||||
	p := wire_trafficPacket{
 | 
						p := wire_trafficPacket{
 | 
				
			||||||
		Coords:  coords,
 | 
							Coords:  coords,
 | 
				
			||||||
| 
						 | 
					@ -115,7 +118,7 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
 | 
				
			||||||
		Payload: payload,
 | 
							Payload: payload,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	packet := p.encode()
 | 
						packet := p.encode()
 | 
				
			||||||
	c.session.bytesSent += uint64(len(b))
 | 
						atomic.AddUint64(&c.session.bytesSent, uint64(len(b)))
 | 
				
			||||||
	select {
 | 
						select {
 | 
				
			||||||
	case c.session.send <- packet:
 | 
						case c.session.send <- packet:
 | 
				
			||||||
	case <-c.session.closed:
 | 
						case <-c.session.closed:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,9 +29,10 @@ type sessionInfo struct {
 | 
				
			||||||
	theirHandle     crypto.Handle
 | 
						theirHandle     crypto.Handle
 | 
				
			||||||
	myHandle        crypto.Handle
 | 
						myHandle        crypto.Handle
 | 
				
			||||||
	theirNonce      crypto.BoxNonce
 | 
						theirNonce      crypto.BoxNonce
 | 
				
			||||||
	theirNonceMutex sync.RWMutex // protects the above
 | 
						theirNonceMask  uint64
 | 
				
			||||||
 | 
						theirNonceMutex sync.Mutex // protects the above
 | 
				
			||||||
	myNonce         crypto.BoxNonce
 | 
						myNonce         crypto.BoxNonce
 | 
				
			||||||
	myNonceMutex    sync.RWMutex // protects the above
 | 
						myNonceMutex    sync.Mutex // protects the above
 | 
				
			||||||
	theirMTU        uint16
 | 
						theirMTU        uint16
 | 
				
			||||||
	myMTU           uint16
 | 
						myMTU           uint16
 | 
				
			||||||
	wasMTUFixed     bool      // Was the MTU fixed by a receive error?
 | 
						wasMTUFixed     bool      // Was the MTU fixed by a receive error?
 | 
				
			||||||
| 
						 | 
					@ -42,7 +43,6 @@ type sessionInfo struct {
 | 
				
			||||||
	send            chan []byte
 | 
						send            chan []byte
 | 
				
			||||||
	recv            chan *wire_trafficPacket
 | 
						recv            chan *wire_trafficPacket
 | 
				
			||||||
	closed          chan interface{}
 | 
						closed          chan interface{}
 | 
				
			||||||
	nonceMask       uint64
 | 
					 | 
				
			||||||
	tstamp          int64     // tstamp from their last session ping, replay attack mitigation
 | 
						tstamp          int64     // tstamp from their last session ping, replay attack mitigation
 | 
				
			||||||
	tstampMutex     int64     // protects the above
 | 
						tstampMutex     int64     // protects the above
 | 
				
			||||||
	mtuTime         time.Time // time myMTU was last changed
 | 
						mtuTime         time.Time // time myMTU was last changed
 | 
				
			||||||
| 
						 | 
					@ -79,8 +79,10 @@ func (s *sessionInfo) update(p *sessionPing) bool {
 | 
				
			||||||
		s.theirSesPub = p.SendSesPub
 | 
							s.theirSesPub = p.SendSesPub
 | 
				
			||||||
		s.theirHandle = p.Handle
 | 
							s.theirHandle = p.Handle
 | 
				
			||||||
		s.sharedSesKey = *crypto.GetSharedKey(&s.mySesPriv, &s.theirSesPub)
 | 
							s.sharedSesKey = *crypto.GetSharedKey(&s.mySesPriv, &s.theirSesPub)
 | 
				
			||||||
 | 
							s.theirNonceMutex.Lock()
 | 
				
			||||||
		s.theirNonce = crypto.BoxNonce{}
 | 
							s.theirNonce = crypto.BoxNonce{}
 | 
				
			||||||
		s.nonceMask = 0
 | 
							s.theirNonceMask = 0
 | 
				
			||||||
 | 
							s.theirNonceMutex.Unlock()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if p.MTU >= 1280 || p.MTU == 0 {
 | 
						if p.MTU >= 1280 || p.MTU == 0 {
 | 
				
			||||||
		s.theirMTU = p.MTU
 | 
							s.theirMTU = p.MTU
 | 
				
			||||||
| 
						 | 
					@ -270,6 +272,10 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	sinfo := sessionInfo{}
 | 
						sinfo := sessionInfo{}
 | 
				
			||||||
 | 
						sinfo.myNonceMutex.Lock()
 | 
				
			||||||
 | 
						sinfo.theirNonceMutex.Lock()
 | 
				
			||||||
 | 
						defer sinfo.myNonceMutex.Unlock()
 | 
				
			||||||
 | 
						defer sinfo.theirNonceMutex.Unlock()
 | 
				
			||||||
	sinfo.core = ss.core
 | 
						sinfo.core = ss.core
 | 
				
			||||||
	sinfo.reconfigure = make(chan chan error, 1)
 | 
						sinfo.reconfigure = make(chan chan error, 1)
 | 
				
			||||||
	sinfo.theirPermPub = *theirPermKey
 | 
						sinfo.theirPermPub = *theirPermKey
 | 
				
			||||||
| 
						 | 
					@ -389,7 +395,9 @@ func (ss *sessions) getPing(sinfo *sessionInfo) sessionPing {
 | 
				
			||||||
		Coords:      coords,
 | 
							Coords:      coords,
 | 
				
			||||||
		MTU:         sinfo.myMTU,
 | 
							MTU:         sinfo.myMTU,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						sinfo.myNonceMutex.Lock()
 | 
				
			||||||
	sinfo.myNonce.Increment()
 | 
						sinfo.myNonce.Increment()
 | 
				
			||||||
 | 
						sinfo.myNonceMutex.Unlock()
 | 
				
			||||||
	return ref
 | 
						return ref
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -493,26 +501,30 @@ func (sinfo *sessionInfo) getMTU() uint16 {
 | 
				
			||||||
// Checks if a packet's nonce is recent enough to fall within the window of allowed packets, and not already received.
 | 
					// Checks if a packet's nonce is recent enough to fall within the window of allowed packets, and not already received.
 | 
				
			||||||
func (sinfo *sessionInfo) nonceIsOK(theirNonce *crypto.BoxNonce) bool {
 | 
					func (sinfo *sessionInfo) nonceIsOK(theirNonce *crypto.BoxNonce) bool {
 | 
				
			||||||
	// The bitmask is to allow for some non-duplicate out-of-order packets
 | 
						// The bitmask is to allow for some non-duplicate out-of-order packets
 | 
				
			||||||
 | 
						sinfo.theirNonceMutex.Lock()
 | 
				
			||||||
 | 
						defer sinfo.theirNonceMutex.Unlock()
 | 
				
			||||||
	diff := theirNonce.Minus(&sinfo.theirNonce)
 | 
						diff := theirNonce.Minus(&sinfo.theirNonce)
 | 
				
			||||||
	if diff > 0 {
 | 
						if diff > 0 {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return ^sinfo.nonceMask&(0x01<<uint64(-diff)) != 0
 | 
						return ^sinfo.theirNonceMask&(0x01<<uint64(-diff)) != 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Updates the nonce mask by (possibly) shifting the bitmask and setting the bit corresponding to this nonce to 1, and then updating the most recent nonce
 | 
					// Updates the nonce mask by (possibly) shifting the bitmask and setting the bit corresponding to this nonce to 1, and then updating the most recent nonce
 | 
				
			||||||
func (sinfo *sessionInfo) updateNonce(theirNonce *crypto.BoxNonce) {
 | 
					func (sinfo *sessionInfo) updateNonce(theirNonce *crypto.BoxNonce) {
 | 
				
			||||||
 | 
						sinfo.theirNonceMutex.Lock()
 | 
				
			||||||
 | 
						defer sinfo.theirNonceMutex.Unlock()
 | 
				
			||||||
	// Shift nonce mask if needed
 | 
						// Shift nonce mask if needed
 | 
				
			||||||
	// Set bit
 | 
						// Set bit
 | 
				
			||||||
	diff := theirNonce.Minus(&sinfo.theirNonce)
 | 
						diff := theirNonce.Minus(&sinfo.theirNonce)
 | 
				
			||||||
	if diff > 0 {
 | 
						if diff > 0 {
 | 
				
			||||||
		// This nonce is newer, so shift the window before setting the bit, and update theirNonce in the session info.
 | 
							// This nonce is newer, so shift the window before setting the bit, and update theirNonce in the session info.
 | 
				
			||||||
		sinfo.nonceMask <<= uint64(diff)
 | 
							sinfo.theirNonceMask <<= uint64(diff)
 | 
				
			||||||
		sinfo.nonceMask &= 0x01
 | 
							sinfo.theirNonceMask &= 0x01
 | 
				
			||||||
		sinfo.theirNonce = *theirNonce
 | 
							sinfo.theirNonce = *theirNonce
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		// This nonce is older, so set the bit but do not shift the window.
 | 
							// This nonce is older, so set the bit but do not shift the window.
 | 
				
			||||||
		sinfo.nonceMask &= 0x01 << uint64(-diff)
 | 
							sinfo.theirNonceMask &= 0x01 << uint64(-diff)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue