mirror of
				https://github.com/yggdrasil-network/yggdrasil-go.git
				synced 2025-11-04 03:05:07 +03:00 
			
		
		
		
	Refactor multicast so that it creates a new TCP listener for each interface with LL addresses (so that it will not break if Listen is not set with a wildcard address)
This commit is contained in:
		
							parent
							
								
									2419b61b2c
								
							
						
					
					
						commit
						de2aff2758
					
				
					 3 changed files with 68 additions and 50 deletions
				
			
		| 
						 | 
				
			
			@ -119,7 +119,8 @@ func (l *link) listen(uri string) error {
 | 
			
		|||
	}
 | 
			
		||||
	switch u.Scheme {
 | 
			
		||||
	case "tcp":
 | 
			
		||||
		return l.tcp.listen(u.Host)
 | 
			
		||||
		_, err := l.tcp.listen(u.Host)
 | 
			
		||||
		return err
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("unknown listen scheme: " + u.Scheme)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,6 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
| 
						 | 
				
			
			@ -16,19 +15,16 @@ type multicast struct {
 | 
			
		|||
	reconfigure chan chan error
 | 
			
		||||
	sock        *ipv6.PacketConn
 | 
			
		||||
	groupAddr   string
 | 
			
		||||
	myAddr      *net.TCPAddr
 | 
			
		||||
	myAddrMutex sync.RWMutex
 | 
			
		||||
	listeners   map[string]*tcpListener
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *multicast) init(core *Core) {
 | 
			
		||||
	m.core = core
 | 
			
		||||
	m.reconfigure = make(chan chan error, 1)
 | 
			
		||||
	m.listeners = make(map[string]*tcpListener)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			e := <-m.reconfigure
 | 
			
		||||
			m.myAddrMutex.Lock()
 | 
			
		||||
			m.myAddr = m.core.link.tcp.getAddr()
 | 
			
		||||
			m.myAddrMutex.Unlock()
 | 
			
		||||
			e <- nil
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
| 
						 | 
				
			
			@ -94,10 +90,12 @@ func (m *multicast) interfaces() []net.Interface {
 | 
			
		|||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, expr := range exprs {
 | 
			
		||||
			// Compile each regular expression
 | 
			
		||||
			e, err := regexp.Compile(expr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				panic(err)
 | 
			
		||||
			}
 | 
			
		||||
			// Does the interface match the regular expression? Store it if so
 | 
			
		||||
			if e.MatchString(iface.Name) {
 | 
			
		||||
				interfaces = append(interfaces, iface)
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -107,10 +105,6 @@ func (m *multicast) interfaces() []net.Interface {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (m *multicast) announce() {
 | 
			
		||||
	var anAddr net.TCPAddr
 | 
			
		||||
	m.myAddrMutex.Lock()
 | 
			
		||||
	m.myAddr = m.core.link.tcp.getAddr()
 | 
			
		||||
	m.myAddrMutex.Unlock()
 | 
			
		||||
	groupAddr, err := net.ResolveUDPAddr("udp6", m.groupAddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -121,27 +115,47 @@ func (m *multicast) announce() {
 | 
			
		|||
	}
 | 
			
		||||
	for {
 | 
			
		||||
		for _, iface := range m.interfaces() {
 | 
			
		||||
			m.sock.JoinGroup(&iface, groupAddr)
 | 
			
		||||
			// Find interface addresses
 | 
			
		||||
			addrs, err := iface.Addrs()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				panic(err)
 | 
			
		||||
			}
 | 
			
		||||
			m.myAddrMutex.RLock()
 | 
			
		||||
			anAddr.Port = m.myAddr.Port
 | 
			
		||||
			m.myAddrMutex.RUnlock()
 | 
			
		||||
			for _, addr := range addrs {
 | 
			
		||||
				addrIP, _, _ := net.ParseCIDR(addr.String())
 | 
			
		||||
				// Ignore IPv4 addresses
 | 
			
		||||
				if addrIP.To4() != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				} // IPv6 only
 | 
			
		||||
				}
 | 
			
		||||
				// Ignore non-link-local addresses
 | 
			
		||||
				if !addrIP.IsLinkLocalUnicast() {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				anAddr.IP = addrIP
 | 
			
		||||
				anAddr.Zone = iface.Name
 | 
			
		||||
				destAddr.Zone = iface.Name
 | 
			
		||||
				msg := []byte(anAddr.String())
 | 
			
		||||
				m.sock.WriteTo(msg, nil, destAddr)
 | 
			
		||||
				// Join the multicast group
 | 
			
		||||
				m.sock.JoinGroup(&iface, groupAddr)
 | 
			
		||||
				// Try and see if we already have a TCP listener for this interface
 | 
			
		||||
				var listener *tcpListener
 | 
			
		||||
				if _, ok := m.listeners[iface.Name]; !ok {
 | 
			
		||||
					// No listener was found - let's create one
 | 
			
		||||
					listenaddr := fmt.Sprintf("[%s%%%s]:0", addrIP, iface.Name)
 | 
			
		||||
					if l, err := m.core.link.tcp.listen(listenaddr); err == nil {
 | 
			
		||||
						// Store the listener so that we can stop it later if needed
 | 
			
		||||
						listener = &tcpListener{
 | 
			
		||||
							listener: l,
 | 
			
		||||
							stop:     make(chan bool),
 | 
			
		||||
						}
 | 
			
		||||
						m.listeners[iface.Name] = listener
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					// An existing listener was found
 | 
			
		||||
					listener = m.listeners[iface.Name]
 | 
			
		||||
				}
 | 
			
		||||
				// Get the listener details and construct the multicast beacon
 | 
			
		||||
				lladdr := (*listener.listener).Addr().String()
 | 
			
		||||
				if a, err := net.ResolveTCPAddr("tcp6", lladdr); err == nil {
 | 
			
		||||
					destAddr.Zone = iface.Name
 | 
			
		||||
					msg := []byte(a.String())
 | 
			
		||||
					m.sock.WriteTo(msg, nil, destAddr)
 | 
			
		||||
				}
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,13 +33,17 @@ const tcp_ping_interval = (default_timeout * 2 / 3)
 | 
			
		|||
 | 
			
		||||
// The TCP listener and information about active TCP connections, to avoid duplication.
 | 
			
		||||
type tcp struct {
 | 
			
		||||
	link          *link
 | 
			
		||||
	reconfigure   chan chan error
 | 
			
		||||
	mutex         sync.Mutex // Protecting the below
 | 
			
		||||
	listeners     map[string]net.Listener
 | 
			
		||||
	listenerstops map[string]chan bool
 | 
			
		||||
	calls         map[string]struct{}
 | 
			
		||||
	conns         map[linkInfo](chan struct{})
 | 
			
		||||
	link        *link
 | 
			
		||||
	reconfigure chan chan error
 | 
			
		||||
	mutex       sync.Mutex // Protecting the below
 | 
			
		||||
	listeners   map[string]*tcpListener
 | 
			
		||||
	calls       map[string]struct{}
 | 
			
		||||
	conns       map[linkInfo](chan struct{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tcpListener struct {
 | 
			
		||||
	listener *net.Listener
 | 
			
		||||
	stop     chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Wrapper function to set additional options for specific connection types.
 | 
			
		||||
| 
						 | 
				
			
			@ -60,7 +64,7 @@ func (t *tcp) getAddr() *net.TCPAddr {
 | 
			
		|||
	t.mutex.Lock()
 | 
			
		||||
	defer t.mutex.Unlock()
 | 
			
		||||
	for _, listener := range t.listeners {
 | 
			
		||||
		return listener.Addr().(*net.TCPAddr)
 | 
			
		||||
		return (*listener.listener).Addr().(*net.TCPAddr)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -72,8 +76,7 @@ func (t *tcp) init(l *link) error {
 | 
			
		|||
	t.mutex.Lock()
 | 
			
		||||
	t.calls = make(map[string]struct{})
 | 
			
		||||
	t.conns = make(map[linkInfo](chan struct{}))
 | 
			
		||||
	t.listeners = make(map[string]net.Listener)
 | 
			
		||||
	t.listenerstops = make(map[string]chan bool)
 | 
			
		||||
	t.listeners = make(map[string]*tcpListener)
 | 
			
		||||
	t.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
| 
						 | 
				
			
			@ -89,7 +92,7 @@ func (t *tcp) init(l *link) error {
 | 
			
		|||
						e <- errors.New("unknown scheme: " + add)
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					if err := t.listen(add[6:]); err != nil {
 | 
			
		||||
					if _, err := t.listen(add[6:]); err != nil {
 | 
			
		||||
						e <- err
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
| 
						 | 
				
			
			@ -110,7 +113,7 @@ func (t *tcp) init(l *link) error {
 | 
			
		|||
		if listenaddr[:6] != "tcp://" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if err := t.listen(listenaddr[6:]); err != nil {
 | 
			
		||||
		if _, err := t.listen(listenaddr[6:]); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -118,7 +121,7 @@ func (t *tcp) init(l *link) error {
 | 
			
		|||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *tcp) listen(listenaddr string) error {
 | 
			
		||||
func (t *tcp) listen(listenaddr string) (*net.Listener, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
| 
						 | 
				
			
			@ -127,36 +130,36 @@ func (t *tcp) listen(listenaddr string) error {
 | 
			
		|||
	}
 | 
			
		||||
	listener, err := lc.Listen(ctx, "tcp", listenaddr)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		l := tcpListener{
 | 
			
		||||
			listener: &listener,
 | 
			
		||||
			stop:     make(chan bool, 1),
 | 
			
		||||
		}
 | 
			
		||||
		t.mutex.Lock()
 | 
			
		||||
		t.listeners[listenaddr] = listener
 | 
			
		||||
		t.listenerstops[listenaddr] = make(chan bool, 1)
 | 
			
		||||
		t.listeners[listenaddr[6:]] = &l
 | 
			
		||||
		t.mutex.Unlock()
 | 
			
		||||
		go t.listener(listenaddr)
 | 
			
		||||
		return nil
 | 
			
		||||
		go t.listener(&l)
 | 
			
		||||
		return &listener, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
	return nil, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Runs the listener, which spawns off goroutines for incoming connections.
 | 
			
		||||
func (t *tcp) listener(listenaddr string) {
 | 
			
		||||
	t.mutex.Lock()
 | 
			
		||||
	listener, ok1 := t.listeners[listenaddr]
 | 
			
		||||
	listenerstop, ok2 := t.listenerstops[listenaddr]
 | 
			
		||||
	t.mutex.Unlock()
 | 
			
		||||
	if !ok1 || !ok2 {
 | 
			
		||||
		t.link.core.log.Errorln("Tried to start TCP listener for", listenaddr, "which doesn't exist")
 | 
			
		||||
func (t *tcp) listener(listener *tcpListener) {
 | 
			
		||||
	if listener == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	reallistenaddr := listener.Addr().String()
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	reallistener := *listener.listener
 | 
			
		||||
	reallistenaddr := reallistener.Addr().String()
 | 
			
		||||
	stop := listener.stop
 | 
			
		||||
	defer reallistener.Close()
 | 
			
		||||
	t.link.core.log.Infoln("Listening for TCP on:", reallistenaddr)
 | 
			
		||||
	accepted := make(chan bool)
 | 
			
		||||
	for {
 | 
			
		||||
		var sock net.Conn
 | 
			
		||||
		var err error
 | 
			
		||||
		go func() {
 | 
			
		||||
			sock, err = listener.Accept()
 | 
			
		||||
			sock, err = reallistener.Accept()
 | 
			
		||||
			accepted <- true
 | 
			
		||||
		}()
 | 
			
		||||
		select {
 | 
			
		||||
| 
						 | 
				
			
			@ -166,7 +169,7 @@ func (t *tcp) listener(listenaddr string) {
 | 
			
		|||
				return
 | 
			
		||||
			}
 | 
			
		||||
			go t.handler(sock, true)
 | 
			
		||||
		case <-listenerstop:
 | 
			
		||||
		case <-stop:
 | 
			
		||||
			t.link.core.log.Errorln("Stopping TCP listener on:", reallistenaddr)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue