From 3ff2b83e76addf1f41c3d5f3a9727383e70f2ccc Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Sun, 4 Sep 2022 17:35:06 +0100 Subject: [PATCH] Cleaner shutdowns, UNIX socket support, more tweaks --- cmd/yggdrasil/main.go | 8 ++-- src/admin/admin.go | 3 ++ src/core/api.go | 2 + src/core/core.go | 30 ++++++++----- src/core/link.go | 59 +++++++++++++++++++++++--- src/core/link_tcp.go | 29 ++++++------- src/core/link_tls.go | 26 +++++++----- src/core/link_unix.go | 98 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 212 insertions(+), 43 deletions(-) create mode 100644 src/core/link_unix.go diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index bd5294fd..1ef6738b 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -322,7 +322,9 @@ func run(args yggArgs, ctx context.Context, done chan struct{}) { if n.admin, err = admin.New(n.core, logger, options...); err != nil { panic(err) } - n.admin.SetupAdminHandlers() + if n.admin != nil { + n.admin.SetupAdminHandlers() + } } // Setup the multicast module. @@ -339,7 +341,7 @@ func run(args yggArgs, ctx context.Context, done chan struct{}) { if n.multicast, err = multicast.New(n.core, logger, options...); err != nil { panic(err) } - if n.admin != nil { + if n.admin != nil && n.multicast != nil { n.multicast.SetupAdminHandlers(n.admin) } } @@ -353,7 +355,7 @@ func run(args yggArgs, ctx context.Context, done chan struct{}) { if n.tuntap, err = tuntap.New(ipv6rwc.NewReadWriteCloser(n.core), logger, options...); err != nil { panic(err) } - if n.admin != nil { + if n.admin != nil && n.tuntap != nil { n.tuntap.SetupAdminHandlers(n.admin) } } diff --git a/src/admin/admin.go b/src/admin/admin.go index 1a402060..4e98c891 100644 --- a/src/admin/admin.go +++ b/src/admin/admin.go @@ -175,6 +175,9 @@ func (a *AdminSocket) IsStarted() bool { // Stop will stop the admin API and close the socket. func (a *AdminSocket) Stop() error { + if a == nil { + return nil + } if a.listener != nil { select { case <-a.done: diff --git a/src/core/api.go b/src/core/api.go index b5f09ab5..571b007d 100644 --- a/src/core/api.go +++ b/src/core/api.go @@ -143,6 +143,8 @@ func (c *Core) Listen(u *url.URL, sintf string) (*Listener, error) { return c.links.tcp.listen(u, sintf) case "tls": return c.links.tls.listen(u, sintf) + case "unix": + return c.links.unix.listen(u, sintf) default: return nil, fmt.Errorf("unrecognised scheme %q", u.Scheme) } diff --git a/src/core/core.go b/src/core/core.go index 8b22ead1..4cc08ad6 100644 --- a/src/core/core.go +++ b/src/core/core.go @@ -48,6 +48,12 @@ func New(secret ed25519.PrivateKey, logger util.Logger, opts ...SetupOption) (*C c := &Core{ log: logger, } + if name := version.BuildName(); name != "unknown" { + c.log.Infoln("Build name:", name) + } + if version := version.BuildVersion(); version != "unknown" { + c.log.Infoln("Build version:", version) + } c.ctx, c.cancel = context.WithCancel(context.Background()) // Take a copy of the private key so that it is in our own memory space. if len(secret) != ed25519.PrivateKeySize { @@ -76,15 +82,17 @@ func New(secret ed25519.PrivateKey, logger util.Logger, opts ...SetupOption) (*C if err := c.proto.nodeinfo.setNodeInfo(c.config.nodeinfo, bool(c.config.nodeinfoPrivacy)); err != nil { return nil, fmt.Errorf("error setting node info: %w", err) } - c.addPeerTimer = time.AfterFunc(time.Minute, func() { - c.Act(nil, c._addPeerLoop) - }) - if name := version.BuildName(); name != "unknown" { - c.log.Infoln("Build name:", name) - } - if version := version.BuildVersion(); version != "unknown" { - c.log.Infoln("Build version:", version) + for listenaddr := range c.config._listeners { + u, err := url.Parse(string(listenaddr)) + if err != nil { + c.log.Errorf("Invalid listener URI %q specified, ignoring\n", listenaddr) + continue + } + if _, err = c.links.listen(u, ""); err != nil { + c.log.Errorf("Failed to start listener %q: %s\n", listenaddr, err) + } } + c.Act(nil, c._addPeerLoop) return c, nil } @@ -92,10 +100,11 @@ func New(secret ed25519.PrivateKey, logger util.Logger, opts ...SetupOption) (*C // configure them. The loop ensures that disconnected peers will eventually // be reconnected with. func (c *Core) _addPeerLoop() { - if c.addPeerTimer == nil { + select { + case <-c.ctx.Done(): return + default: } - // Add peers from the Peers section for peer := range c.config._peers { go func(peer string, intf string) { @@ -126,6 +135,7 @@ func (c *Core) Stop() { // This function is unsafe and should only be ran by the core actor. func (c *Core) _close() error { c.cancel() + _ = c.links.shutdown() err := c.PacketConn.Close() if c.addPeerTimer != nil { c.addPeerTimer.Stop() diff --git a/src/core/link.go b/src/core/link.go index e00cb6f6..6a8a0527 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -2,7 +2,6 @@ package core import ( "bytes" - "context" "encoding/hex" "errors" "fmt" @@ -27,6 +26,7 @@ type links struct { core *Core tcp *linkTCP // TCP interface support tls *linkTLS // TLS interface support + unix *linkUNIX // UNIX interface support mutex sync.RWMutex // protects links below links map[linkInfo]*link // *link is nil if connection in progress // TODO timeout (to remove from switch), read from config.ReadTimeout @@ -55,13 +55,20 @@ type linkOptions struct { type Listener struct { net.Listener - Close context.CancelFunc // deliberately replaces net.Listener.Close() + closed chan struct{} +} + +func (l *Listener) Close() error { + err := l.Listener.Close() + <-l.closed + return err } func (l *links) init(c *Core) error { l.core = c l.tcp = l.newLinkTCP() l.tls = l.newLinkTLS(l.tcp) + l.unix = l.newLinkUNIX() l.mutex.Lock() l.links = make(map[linkInfo]*link) @@ -78,6 +85,25 @@ func (l *links) init(c *Core) error { return nil } +func (l *links) shutdown() error { + phony.Block(l.tcp, func() { + for l := range l.tcp._listeners { + l.Close() + } + }) + phony.Block(l.tls, func() { + for l := range l.tls._listeners { + l.Close() + } + }) + phony.Block(l.unix, func() { + for l := range l.unix._listeners { + l.Close() + } + }) + return nil +} + func (l *links) isConnectedTo(info linkInfo) bool { l.mutex.RLock() defer l.mutex.RUnlock() @@ -146,12 +172,35 @@ func (l *links) call(u *url.URL, sintf string) error { } }() + case "unix": + go func() { + if err := l.unix.dial(u, tcpOpts.linkOptions, sintf); err != nil { + l.core.log.Warnf("Failed to dial UNIX %s: %s\n", u.Host, err) + } + }() + default: return errors.New("unknown call scheme: " + u.Scheme) } return nil } +func (l *links) listen(u *url.URL, sintf string) (*Listener, error) { + var listener *Listener + var err error + switch u.Scheme { + case "tcp": + listener, err = l.tcp.listen(u, sintf) + case "tls": + listener, err = l.tls.listen(u, sintf) + case "unix": + listener, err = l.unix.listen(u, sintf) + default: + return nil, fmt.Errorf("unrecognised scheme %q", u.Scheme) + } + return listener, err +} + func (l *links) create(conn net.Conn, name string, info linkInfo, incoming, force bool, options linkOptions) error { intf := link{ conn: &linkConn{ @@ -167,7 +216,7 @@ func (l *links) create(conn net.Conn, name string, info linkInfo, incoming, forc } go func() { if err := intf.handler(); err != nil { - l.core.log.Errorf("Link handler error (%s): %s", conn.RemoteAddr(), err) + l.core.log.Errorf("Link handler %s error (%s): %s", name, conn.RemoteAddr(), err) } }() return nil @@ -178,7 +227,7 @@ func (intf *link) handler() error { // Don't connect to this link more than once. if intf.links.isConnectedTo(intf.info) { - return fmt.Errorf("already connected to %+v", intf.info) + return fmt.Errorf("already connected to this node") } // Mark the connection as in progress. @@ -280,7 +329,7 @@ func (intf *link) handler() error { strings.ToUpper(intf.info.linkType), remoteStr, localStr) // TODO don't report an error if it's just a 'use of closed network connection' - if err = intf.links.core.HandleConn(meta.key, intf.conn); err != nil { + if err = intf.links.core.HandleConn(meta.key, intf.conn); err != nil && err != io.EOF { intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s", strings.ToUpper(intf.info.linkType), remoteStr, localStr, err) } else { diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index 4649b332..767df0c8 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -16,7 +16,7 @@ type linkTCP struct { phony.Inbox *links listener *net.ListenConfig - _listeners map[net.Listener]context.CancelFunc + _listeners map[*Listener]context.CancelFunc } type tcpOptions struct { @@ -33,7 +33,7 @@ func (l *links) newLinkTCP() *linkTCP { listener: &net.ListenConfig{ KeepAlive: -1, }, - _listeners: map[net.Listener]context.CancelFunc{}, + _listeners: map[*Listener]context.CancelFunc{}, } lt.listener.Control = lt.tcpContext return lt @@ -64,8 +64,7 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { ctx, cancel := context.WithCancel(l.core.ctx) hostport := url.Host if sintf != "" { - host, port, err := net.SplitHostPort(hostport) - if err == nil { + if host, port, err := net.SplitHostPort(hostport); err == nil { hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port) } } @@ -74,18 +73,23 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { cancel() return nil, err } + entry := &Listener{ + Listener: listener, + closed: make(chan struct{}), + } phony.Block(l, func() { - l._listeners[listener] = cancel + l._listeners[entry] = cancel }) + l.core.log.Printf("TCP listener started on %s", listener.Addr()) go func() { defer phony.Block(l, func() { - delete(l._listeners, listener) + delete(l._listeners, entry) }) for { conn, err := listener.Accept() if err != nil { cancel() - return + break } addr := conn.RemoteAddr().(*net.TCPAddr) name := fmt.Sprintf("tls://%s", addr) @@ -94,11 +98,11 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { l.core.log.Errorln("Failed to create inbound link:", err) } } + listener.Close() + close(entry.closed) + l.core.log.Printf("TCP listener stopped on %s", listener.Addr()) }() - return &Listener{ - Listener: listener, - Close: cancel, - }, nil + return entry, nil } func (l *linkTCP) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error { @@ -160,9 +164,6 @@ func (l *linkTCP) dialerFor(saddr, sintf string) (*net.Dialer, error) { if err != nil { continue } - if src.Equal(dst.IP) { - continue - } if !src.IsGlobalUnicast() && !src.IsLinkLocalUnicast() { continue } diff --git a/src/core/link_tls.go b/src/core/link_tls.go index fa4d37c1..40c3196c 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -25,7 +25,7 @@ type linkTLS struct { tcp *linkTCP listener *net.ListenConfig config *tls.Config - _listeners map[net.Listener]context.CancelFunc + _listeners map[*Listener]context.CancelFunc } func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { @@ -36,7 +36,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { Control: tcp.tcpContext, KeepAlive: -1, }, - _listeners: map[net.Listener]context.CancelFunc{}, + _listeners: map[*Listener]context.CancelFunc{}, } var err error lt.config, err = lt.generateConfig() @@ -75,8 +75,7 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { ctx, cancel := context.WithCancel(l.core.ctx) hostport := url.Host if sintf != "" { - host, port, err := net.SplitHostPort(hostport) - if err == nil { + if host, port, err := net.SplitHostPort(hostport); err == nil { hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port) } } @@ -86,18 +85,23 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { return nil, err } tlslistener := tls.NewListener(listener, l.config) + entry := &Listener{ + Listener: tlslistener, + closed: make(chan struct{}), + } phony.Block(l, func() { - l._listeners[tlslistener] = cancel + l._listeners[entry] = cancel }) + l.core.log.Printf("TLS listener started on %s", listener.Addr()) go func() { defer phony.Block(l, func() { - delete(l._listeners, tlslistener) + delete(l._listeners, entry) }) for { conn, err := tlslistener.Accept() if err != nil { cancel() - return + break } addr := conn.RemoteAddr().(*net.TCPAddr) name := fmt.Sprintf("tls://%s", addr) @@ -106,11 +110,11 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { l.core.log.Errorln("Failed to create inbound link:", err) } } + tlslistener.Close() + close(entry.closed) + l.core.log.Printf("TLS listener stopped on %s", listener.Addr()) }() - return &Listener{ - Listener: tlslistener, - Close: cancel, - }, nil + return entry, nil } func (l *linkTLS) generateConfig() (*tls.Config, error) { diff --git a/src/core/link_unix.go b/src/core/link_unix.go new file mode 100644 index 00000000..d63c1d90 --- /dev/null +++ b/src/core/link_unix.go @@ -0,0 +1,98 @@ +package core + +import ( + "context" + "fmt" + "net" + "net/url" + "time" + + "github.com/Arceliar/phony" +) + +type linkUNIX struct { + phony.Inbox + *links + dialer *net.Dialer + listener *net.ListenConfig + _listeners map[*Listener]context.CancelFunc +} + +func (l *links) newLinkUNIX() *linkUNIX { + lt := &linkUNIX{ + links: l, + dialer: &net.Dialer{ + Timeout: time.Second * 5, + KeepAlive: -1, + }, + listener: &net.ListenConfig{ + KeepAlive: -1, + }, + _listeners: map[*Listener]context.CancelFunc{}, + } + return lt +} + +func (l *linkUNIX) dial(url *url.URL, options linkOptions, _ string) error { + info := linkInfoFor("unix", "", url.Path) + if l.links.isConnectedTo(info) { + return fmt.Errorf("duplicate connection attempt") + } + addr, err := net.ResolveUnixAddr("unix", url.Path) + if err != nil { + return err + } + conn, err := l.dialer.DialContext(l.core.ctx, "unix", addr.String()) + if err != nil { + return err + } + return l.handler(url.String(), info, conn, options, false) +} + +func (l *linkUNIX) listen(url *url.URL, _ string) (*Listener, error) { + ctx, cancel := context.WithCancel(l.core.ctx) + listener, err := l.listener.Listen(ctx, "unix", url.Path) + if err != nil { + cancel() + return nil, err + } + entry := &Listener{ + Listener: listener, + closed: make(chan struct{}), + } + phony.Block(l, func() { + l._listeners[entry] = cancel + }) + l.core.log.Printf("UNIX listener started on %s", listener.Addr()) + go func() { + defer phony.Block(l, func() { + delete(l._listeners, entry) + }) + for { + conn, err := listener.Accept() + if err != nil { + cancel() + break + } + info := linkInfoFor("unix", "", url.String()) + if err = l.handler(url.String(), info, conn, linkOptions{}, true); err != nil { + l.core.log.Errorln("Failed to create inbound link:", err) + } + } + listener.Close() + close(entry.closed) + l.core.log.Printf("UNIX listener stopped on %s", listener.Addr()) + }() + return entry, nil +} + +func (l *linkUNIX) handler(name string, info linkInfo, conn net.Conn, options linkOptions, incoming bool) error { + return l.links.create( + conn, // connection + name, // connection name + info, // connection info + incoming, // not incoming + false, // not forced + options, // connection options + ) +}