mirror of
				https://github.com/yggdrasil-network/yggdrasil-go.git
				synced 2025-11-04 11:15:07 +03:00 
			
		
		
		
	Refactor multicast so that it creates a new TCP listener for each interface with LL addresses (so that it will not break if Listen is not set with a wildcard address)
This commit is contained in:
		
							parent
							
								
									2419b61b2c
								
							
						
					
					
						commit
						de2aff2758
					
				
					 3 changed files with 68 additions and 50 deletions
				
			
		| 
						 | 
					@ -119,7 +119,8 @@ func (l *link) listen(uri string) error {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch u.Scheme {
 | 
						switch u.Scheme {
 | 
				
			||||||
	case "tcp":
 | 
						case "tcp":
 | 
				
			||||||
		return l.tcp.listen(u.Host)
 | 
							_, err := l.tcp.listen(u.Host)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return errors.New("unknown listen scheme: " + u.Scheme)
 | 
							return errors.New("unknown listen scheme: " + u.Scheme)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,7 +5,6 @@ import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"golang.org/x/net/ipv6"
 | 
						"golang.org/x/net/ipv6"
 | 
				
			||||||
| 
						 | 
					@ -16,19 +15,16 @@ type multicast struct {
 | 
				
			||||||
	reconfigure chan chan error
 | 
						reconfigure chan chan error
 | 
				
			||||||
	sock        *ipv6.PacketConn
 | 
						sock        *ipv6.PacketConn
 | 
				
			||||||
	groupAddr   string
 | 
						groupAddr   string
 | 
				
			||||||
	myAddr      *net.TCPAddr
 | 
						listeners   map[string]*tcpListener
 | 
				
			||||||
	myAddrMutex sync.RWMutex
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *multicast) init(core *Core) {
 | 
					func (m *multicast) init(core *Core) {
 | 
				
			||||||
	m.core = core
 | 
						m.core = core
 | 
				
			||||||
	m.reconfigure = make(chan chan error, 1)
 | 
						m.reconfigure = make(chan chan error, 1)
 | 
				
			||||||
 | 
						m.listeners = make(map[string]*tcpListener)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		for {
 | 
							for {
 | 
				
			||||||
			e := <-m.reconfigure
 | 
								e := <-m.reconfigure
 | 
				
			||||||
			m.myAddrMutex.Lock()
 | 
					 | 
				
			||||||
			m.myAddr = m.core.link.tcp.getAddr()
 | 
					 | 
				
			||||||
			m.myAddrMutex.Unlock()
 | 
					 | 
				
			||||||
			e <- nil
 | 
								e <- nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
| 
						 | 
					@ -94,10 +90,12 @@ func (m *multicast) interfaces() []net.Interface {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		for _, expr := range exprs {
 | 
							for _, expr := range exprs {
 | 
				
			||||||
 | 
								// Compile each regular expression
 | 
				
			||||||
			e, err := regexp.Compile(expr)
 | 
								e, err := regexp.Compile(expr)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				panic(err)
 | 
									panic(err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								// Does the interface match the regular expression? Store it if so
 | 
				
			||||||
			if e.MatchString(iface.Name) {
 | 
								if e.MatchString(iface.Name) {
 | 
				
			||||||
				interfaces = append(interfaces, iface)
 | 
									interfaces = append(interfaces, iface)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -107,10 +105,6 @@ func (m *multicast) interfaces() []net.Interface {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *multicast) announce() {
 | 
					func (m *multicast) announce() {
 | 
				
			||||||
	var anAddr net.TCPAddr
 | 
					 | 
				
			||||||
	m.myAddrMutex.Lock()
 | 
					 | 
				
			||||||
	m.myAddr = m.core.link.tcp.getAddr()
 | 
					 | 
				
			||||||
	m.myAddrMutex.Unlock()
 | 
					 | 
				
			||||||
	groupAddr, err := net.ResolveUDPAddr("udp6", m.groupAddr)
 | 
						groupAddr, err := net.ResolveUDPAddr("udp6", m.groupAddr)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		panic(err)
 | 
							panic(err)
 | 
				
			||||||
| 
						 | 
					@ -121,27 +115,47 @@ func (m *multicast) announce() {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		for _, iface := range m.interfaces() {
 | 
							for _, iface := range m.interfaces() {
 | 
				
			||||||
			m.sock.JoinGroup(&iface, groupAddr)
 | 
								// Find interface addresses
 | 
				
			||||||
			addrs, err := iface.Addrs()
 | 
								addrs, err := iface.Addrs()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				panic(err)
 | 
									panic(err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			m.myAddrMutex.RLock()
 | 
					 | 
				
			||||||
			anAddr.Port = m.myAddr.Port
 | 
					 | 
				
			||||||
			m.myAddrMutex.RUnlock()
 | 
					 | 
				
			||||||
			for _, addr := range addrs {
 | 
								for _, addr := range addrs {
 | 
				
			||||||
				addrIP, _, _ := net.ParseCIDR(addr.String())
 | 
									addrIP, _, _ := net.ParseCIDR(addr.String())
 | 
				
			||||||
 | 
									// Ignore IPv4 addresses
 | 
				
			||||||
				if addrIP.To4() != nil {
 | 
									if addrIP.To4() != nil {
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
				} // IPv6 only
 | 
									}
 | 
				
			||||||
 | 
									// Ignore non-link-local addresses
 | 
				
			||||||
				if !addrIP.IsLinkLocalUnicast() {
 | 
									if !addrIP.IsLinkLocalUnicast() {
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				anAddr.IP = addrIP
 | 
									// Join the multicast group
 | 
				
			||||||
				anAddr.Zone = iface.Name
 | 
									m.sock.JoinGroup(&iface, groupAddr)
 | 
				
			||||||
 | 
									// Try and see if we already have a TCP listener for this interface
 | 
				
			||||||
 | 
									var listener *tcpListener
 | 
				
			||||||
 | 
									if _, ok := m.listeners[iface.Name]; !ok {
 | 
				
			||||||
 | 
										// No listener was found - let's create one
 | 
				
			||||||
 | 
										listenaddr := fmt.Sprintf("[%s%%%s]:0", addrIP, iface.Name)
 | 
				
			||||||
 | 
										if l, err := m.core.link.tcp.listen(listenaddr); err == nil {
 | 
				
			||||||
 | 
											// Store the listener so that we can stop it later if needed
 | 
				
			||||||
 | 
											listener = &tcpListener{
 | 
				
			||||||
 | 
												listener: l,
 | 
				
			||||||
 | 
												stop:     make(chan bool),
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											m.listeners[iface.Name] = listener
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										// An existing listener was found
 | 
				
			||||||
 | 
										listener = m.listeners[iface.Name]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									// Get the listener details and construct the multicast beacon
 | 
				
			||||||
 | 
									lladdr := (*listener.listener).Addr().String()
 | 
				
			||||||
 | 
									if a, err := net.ResolveTCPAddr("tcp6", lladdr); err == nil {
 | 
				
			||||||
					destAddr.Zone = iface.Name
 | 
										destAddr.Zone = iface.Name
 | 
				
			||||||
				msg := []byte(anAddr.String())
 | 
										msg := []byte(a.String())
 | 
				
			||||||
					m.sock.WriteTo(msg, nil, destAddr)
 | 
										m.sock.WriteTo(msg, nil, destAddr)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			time.Sleep(time.Second)
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -36,12 +36,16 @@ type tcp struct {
 | 
				
			||||||
	link        *link
 | 
						link        *link
 | 
				
			||||||
	reconfigure chan chan error
 | 
						reconfigure chan chan error
 | 
				
			||||||
	mutex       sync.Mutex // Protecting the below
 | 
						mutex       sync.Mutex // Protecting the below
 | 
				
			||||||
	listeners     map[string]net.Listener
 | 
						listeners   map[string]*tcpListener
 | 
				
			||||||
	listenerstops map[string]chan bool
 | 
					 | 
				
			||||||
	calls       map[string]struct{}
 | 
						calls       map[string]struct{}
 | 
				
			||||||
	conns       map[linkInfo](chan struct{})
 | 
						conns       map[linkInfo](chan struct{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type tcpListener struct {
 | 
				
			||||||
 | 
						listener *net.Listener
 | 
				
			||||||
 | 
						stop     chan bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Wrapper function to set additional options for specific connection types.
 | 
					// Wrapper function to set additional options for specific connection types.
 | 
				
			||||||
func (t *tcp) setExtraOptions(c net.Conn) {
 | 
					func (t *tcp) setExtraOptions(c net.Conn) {
 | 
				
			||||||
	switch sock := c.(type) {
 | 
						switch sock := c.(type) {
 | 
				
			||||||
| 
						 | 
					@ -60,7 +64,7 @@ func (t *tcp) getAddr() *net.TCPAddr {
 | 
				
			||||||
	t.mutex.Lock()
 | 
						t.mutex.Lock()
 | 
				
			||||||
	defer t.mutex.Unlock()
 | 
						defer t.mutex.Unlock()
 | 
				
			||||||
	for _, listener := range t.listeners {
 | 
						for _, listener := range t.listeners {
 | 
				
			||||||
		return listener.Addr().(*net.TCPAddr)
 | 
							return (*listener.listener).Addr().(*net.TCPAddr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -72,8 +76,7 @@ func (t *tcp) init(l *link) error {
 | 
				
			||||||
	t.mutex.Lock()
 | 
						t.mutex.Lock()
 | 
				
			||||||
	t.calls = make(map[string]struct{})
 | 
						t.calls = make(map[string]struct{})
 | 
				
			||||||
	t.conns = make(map[linkInfo](chan struct{}))
 | 
						t.conns = make(map[linkInfo](chan struct{}))
 | 
				
			||||||
	t.listeners = make(map[string]net.Listener)
 | 
						t.listeners = make(map[string]*tcpListener)
 | 
				
			||||||
	t.listenerstops = make(map[string]chan bool)
 | 
					 | 
				
			||||||
	t.mutex.Unlock()
 | 
						t.mutex.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
| 
						 | 
					@ -89,7 +92,7 @@ func (t *tcp) init(l *link) error {
 | 
				
			||||||
						e <- errors.New("unknown scheme: " + add)
 | 
											e <- errors.New("unknown scheme: " + add)
 | 
				
			||||||
						continue
 | 
											continue
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					if err := t.listen(add[6:]); err != nil {
 | 
										if _, err := t.listen(add[6:]); err != nil {
 | 
				
			||||||
						e <- err
 | 
											e <- err
 | 
				
			||||||
						continue
 | 
											continue
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
| 
						 | 
					@ -110,7 +113,7 @@ func (t *tcp) init(l *link) error {
 | 
				
			||||||
		if listenaddr[:6] != "tcp://" {
 | 
							if listenaddr[:6] != "tcp://" {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err := t.listen(listenaddr[6:]); err != nil {
 | 
							if _, err := t.listen(listenaddr[6:]); err != nil {
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -118,7 +121,7 @@ func (t *tcp) init(l *link) error {
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *tcp) listen(listenaddr string) error {
 | 
					func (t *tcp) listen(listenaddr string) (*net.Listener, error) {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ctx := context.Background()
 | 
						ctx := context.Background()
 | 
				
			||||||
| 
						 | 
					@ -127,36 +130,36 @@ func (t *tcp) listen(listenaddr string) error {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	listener, err := lc.Listen(ctx, "tcp", listenaddr)
 | 
						listener, err := lc.Listen(ctx, "tcp", listenaddr)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
 | 
							l := tcpListener{
 | 
				
			||||||
 | 
								listener: &listener,
 | 
				
			||||||
 | 
								stop:     make(chan bool, 1),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		t.mutex.Lock()
 | 
							t.mutex.Lock()
 | 
				
			||||||
		t.listeners[listenaddr] = listener
 | 
							t.listeners[listenaddr[6:]] = &l
 | 
				
			||||||
		t.listenerstops[listenaddr] = make(chan bool, 1)
 | 
					 | 
				
			||||||
		t.mutex.Unlock()
 | 
							t.mutex.Unlock()
 | 
				
			||||||
		go t.listener(listenaddr)
 | 
							go t.listener(&l)
 | 
				
			||||||
		return nil
 | 
							return &listener, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return err
 | 
						return nil, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Runs the listener, which spawns off goroutines for incoming connections.
 | 
					// Runs the listener, which spawns off goroutines for incoming connections.
 | 
				
			||||||
func (t *tcp) listener(listenaddr string) {
 | 
					func (t *tcp) listener(listener *tcpListener) {
 | 
				
			||||||
	t.mutex.Lock()
 | 
						if listener == nil {
 | 
				
			||||||
	listener, ok1 := t.listeners[listenaddr]
 | 
					 | 
				
			||||||
	listenerstop, ok2 := t.listenerstops[listenaddr]
 | 
					 | 
				
			||||||
	t.mutex.Unlock()
 | 
					 | 
				
			||||||
	if !ok1 || !ok2 {
 | 
					 | 
				
			||||||
		t.link.core.log.Errorln("Tried to start TCP listener for", listenaddr, "which doesn't exist")
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	reallistenaddr := listener.Addr().String()
 | 
						reallistener := *listener.listener
 | 
				
			||||||
	defer listener.Close()
 | 
						reallistenaddr := reallistener.Addr().String()
 | 
				
			||||||
 | 
						stop := listener.stop
 | 
				
			||||||
 | 
						defer reallistener.Close()
 | 
				
			||||||
	t.link.core.log.Infoln("Listening for TCP on:", reallistenaddr)
 | 
						t.link.core.log.Infoln("Listening for TCP on:", reallistenaddr)
 | 
				
			||||||
	accepted := make(chan bool)
 | 
						accepted := make(chan bool)
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		var sock net.Conn
 | 
							var sock net.Conn
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
			sock, err = listener.Accept()
 | 
								sock, err = reallistener.Accept()
 | 
				
			||||||
			accepted <- true
 | 
								accepted <- true
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
| 
						 | 
					@ -166,7 +169,7 @@ func (t *tcp) listener(listenaddr string) {
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			go t.handler(sock, true)
 | 
								go t.handler(sock, true)
 | 
				
			||||||
		case <-listenerstop:
 | 
							case <-stop:
 | 
				
			||||||
			t.link.core.log.Errorln("Stopping TCP listener on:", reallistenaddr)
 | 
								t.link.core.log.Errorln("Stopping TCP listener on:", reallistenaddr)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue