diff --git a/src/core/link.go b/src/core/link.go index 1aedbd44..5ff6c9a3 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -3,7 +3,6 @@ package core import ( "bytes" "context" - "crypto/ed25519" "encoding/hex" "errors" "fmt" @@ -25,18 +24,16 @@ import ( ) type links struct { - core *Core - tcp *linkTCP // TCP interface support - tls *linkTLS // TLS interface support - mutex sync.RWMutex // protects links below - links map[linkInfo]*link - stopped chan struct{} + core *Core + tcp *linkTCP // TCP interface support + tls *linkTLS // TLS interface support + mutex sync.RWMutex // protects links below + links map[linkInfo]*link // *link is nil if connection in progress // TODO timeout (to remove from switch), read from config.ReadTimeout } // linkInfo is used as a map key type linkInfo struct { - key keyArray linkType string // Type of link, e.g. TCP, AWDL local string // Local name or address remote string // Remote name or address @@ -50,7 +47,6 @@ type link struct { info linkInfo incoming bool force bool - closed chan struct{} } type linkOptions struct { @@ -70,7 +66,6 @@ func (l *links) init(c *Core) error { l.mutex.Lock() l.links = make(map[linkInfo]*link) l.mutex.Unlock() - l.stopped = make(chan struct{}) var listeners []ListenAddress phony.Block(c, func() { @@ -83,8 +78,23 @@ func (l *links) init(c *Core) error { return nil } +func (l *links) isConnectedTo(info linkInfo) bool { + l.mutex.RLock() + defer l.mutex.RUnlock() + fmt.Println(l.links) + _, isConnected := l.links[info] + return isConnected +} + func (l *links) call(u *url.URL, sintf string) error { - // TODO: don't dial duplicates here + info := linkInfo{ + linkType: strings.ToLower(u.Scheme), + local: sintf, + remote: u.Host, + } + if l.isConnectedTo(info) { + return fmt.Errorf("already connected to this node") + } tcpOpts := tcpOptions{ linkOptions: linkOptions{ pinnedEd25519Keys: map[keyArray]struct{}{}, @@ -97,10 +107,10 @@ func (l *links) call(u *url.URL, sintf string) error { tcpOpts.pinnedEd25519Keys[sigPubKey] = struct{}{} } } - switch u.Scheme { + switch info.linkType { case "tcp": go func() { - if _, err := l.tcp.dial(u, tcpOpts, sintf); err != nil { + if err := l.tcp.dial(u, tcpOpts, sintf); err != nil { l.core.log.Warnf("Failed to dial TCP %s: %s\n", u.Host, err) } }() @@ -136,7 +146,7 @@ func (l *links) call(u *url.URL, sintf string) error { } } go func() { - if _, err := l.tls.dial(u, tcpOpts, sintf); err != nil { + if err := l.tls.dial(u, tcpOpts, sintf); err != nil { l.core.log.Warnf("Failed to dial TLS %s: %s\n", u.Host, err) } }() @@ -147,19 +157,7 @@ func (l *links) call(u *url.URL, sintf string) error { return nil } -func (l *links) create(conn net.Conn, name, linkType, local, remote string, incoming, force bool, options linkOptions) (*link, error) { - // Technically anything unique would work for names, but let's pick something human readable, just for debugging - info := linkInfo{ - linkType: linkType, - local: local, - remote: remote, - } - l.mutex.RLock() - _, isIn := l.links[info] - l.mutex.RUnlock() - if isIn { - return nil, fmt.Errorf("duplicate") - } +func (l *links) create(conn net.Conn, name string, info linkInfo, incoming, force bool, options linkOptions) error { intf := link{ conn: &linkConn{ Conn: conn, @@ -173,16 +171,34 @@ func (l *links) create(conn net.Conn, name, linkType, local, remote string, inco force: force, } go func() { - if _, err := intf.handler(); err != nil { - l.core.log.Warnf("Handler error (incoming %v): %s\n", incoming, err) + if err := intf.handler(info); err != nil { + l.core.log.Errorf("Link handler error (%s): %s", conn.RemoteAddr(), err) } }() - return &intf, nil + return nil } -func (intf *link) handler() (chan struct{}, error) { - // TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later +func (intf *link) handler(info linkInfo) error { defer intf.conn.Close() + + // Don't connect to this link more than once. + if intf.links.isConnectedTo(info) { + return fmt.Errorf("already connected to %+v", info) + } + + // Mark the connection as in progress. + intf.links.mutex.Lock() + intf.links.links[info] = nil + intf.links.mutex.Unlock() + + // When we're done, clean up the connection entry. + defer func() { + intf.links.mutex.Lock() + delete(intf.links.links, info) + intf.links.mutex.Unlock() + }() + + // TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later meta := version_getBaseMetadata() meta.key = intf.links.core.public metaBytes := meta.encode() @@ -195,10 +211,10 @@ func (intf *link) handler() (chan struct{}, error) { err = errors.New("incomplete metadata send") } }) { - return nil, errors.New("timeout on metadata send") + return errors.New("timeout on metadata send") } if err != nil { - return nil, fmt.Errorf("write handshake: %w", err) + return fmt.Errorf("write handshake: %w", err) } if !util.FuncTimeout(30*time.Second, func() { var n int @@ -207,15 +223,15 @@ func (intf *link) handler() (chan struct{}, error) { err = errors.New("incomplete metadata recv") } }) { - return nil, errors.New("timeout on metadata recv") + return errors.New("timeout on metadata recv") } if err != nil { - return nil, fmt.Errorf("read handshake: %w", err) + return fmt.Errorf("read handshake: %w", err) } meta = version_metadata{} base := version_getBaseMetadata() if !meta.decode(metaBytes) { - return nil, errors.New("failed to decode metadata") + return errors.New("failed to decode metadata") } if !meta.check() { var connectError string @@ -230,7 +246,7 @@ func (intf *link) handler() (chan struct{}, error) { fmt.Sprintf("%d.%d", base.ver, base.minorVer), fmt.Sprintf("%d.%d", meta.ver, meta.minorVer), ) - return nil, errors.New("remote node is incompatible version") + return errors.New("remote node is incompatible version") } // Check if the remote side matches the keys we expected. This is a bit of a weak // check - in future versions we really should check a signature or something like that. @@ -239,7 +255,7 @@ func (intf *link) handler() (chan struct{}, error) { copy(key[:], meta.key) if _, allowed := pinned[key]; !allowed { intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name()) - return nil, fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") + return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") } } // Check if we're authorized to connect to this key / IP @@ -255,58 +271,50 @@ func (intf *link) handler() (chan struct{}, error) { intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s", strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.key)) intf.close() - return nil, nil + return fmt.Errorf("forbidden connection") } - // Check if we already have a link to this node - copy(intf.info.key[:], meta.key) + intf.links.mutex.Lock() - if oldIntf, isIn := intf.links.links[intf.info]; isIn { - intf.links.mutex.Unlock() - // FIXME we should really return an error and let the caller block instead - // That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc. - intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name()) - return oldIntf.closed, nil - } else { - intf.closed = make(chan struct{}) - intf.links.links[intf.info] = intf - defer func() { - intf.links.mutex.Lock() - delete(intf.links.links, intf.info) - intf.links.mutex.Unlock() - close(intf.closed) - }() - intf.links.core.log.Debugln("DEBUG: registered interface for", intf.name()) - } + intf.links.links[info] = intf intf.links.mutex.Unlock() - themAddr := address.AddrForKey(ed25519.PublicKey(intf.info.key[:])) - themAddrString := net.IP(themAddr[:]).String() - themString := fmt.Sprintf("%s@%s", themAddrString, intf.info.remote) + + remoteAddr := net.IP(address.AddrForKey(meta.key)[:]).String() + remoteStr := fmt.Sprintf("%s@%s", remoteAddr, intf.info.remote) + localStr := intf.conn.LocalAddr() intf.links.core.log.Infof("Connected %s: %s, source %s", - strings.ToUpper(intf.info.linkType), themString, intf.info.local) - // Run the handler - err = intf.links.core.HandleConn(ed25519.PublicKey(intf.info.key[:]), intf.conn) - if err != nil { - err = fmt.Errorf("connection: %w", err) - } + strings.ToUpper(intf.info.linkType), remoteStr, localStr) + // TODO don't report an error if it's just a 'use of closed network connection' - if err != nil { + if err = intf.links.core.HandleConn(meta.key, intf.conn); err != nil { intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s", - strings.ToUpper(intf.info.linkType), themString, intf.info.local, err) + strings.ToUpper(intf.info.linkType), remoteStr, localStr, err) } else { intf.links.core.log.Infof("Disconnected %s: %s, source %s", - strings.ToUpper(intf.info.linkType), themString, intf.info.local) + strings.ToUpper(intf.info.linkType), remoteStr, localStr) } - return nil, err + + return nil } -func (intf *link) close() { - intf.conn.Close() +func (intf *link) close() error { + return intf.conn.Close() } func (intf *link) name() string { return intf.lname } +func linkInfoFor(linkType, local, remote string) linkInfo { + if h, _, err := net.SplitHostPort(remote); err == nil { + remote = h + } + return linkInfo{ + linkType: linkType, + local: local, + remote: remote, + } +} + type linkConn struct { // tx and rx are at the beginning of the struct to ensure 64-bit alignment // on 32-bit platforms, see https://pkg.go.dev/sync/atomic#pkg-note-BUG diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index eefc1903..818ba1ef 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -38,24 +38,25 @@ func (l *links) newLinkTCP() *linkTCP { return lt } -func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) { +func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) error { + info := linkInfoFor("tcp", url.Host, sintf) + if l.links.isConnectedTo(info) { + return fmt.Errorf("duplicate connection attempt") + } addr, err := net.ResolveTCPAddr("tcp", url.Host) if err != nil { - return nil, err + return err } addr.Zone = sintf dialer, err := l.dialerFor(addr.String(), sintf) if err != nil { - return nil, err + return err } conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) if err != nil { - return nil, err + return err } - if _, err = l.handler("TCP", conn, options, false); err != nil { - l.core.log.Errorln("Failed to create outbound link:", err) - } - return nil, err + return l.handler(url.String(), info, conn, options, false) } func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { @@ -85,8 +86,10 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { cancel() return } - if _, err := l.handler("TCP", conn, tcpOptions{}, true); err != nil { - l.core.log.Errorln("Failed to create link:", err) + name := fmt.Sprintf("tcp://%s", conn.RemoteAddr()) + info := linkInfoFor("tcp", sintf, conn.RemoteAddr().String()) + if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil { + l.core.log.Errorln("Failed to create inbound link:", err) } } }() @@ -96,16 +99,14 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { }, nil } -func (l *linkTCP) handler(proto string, conn net.Conn, options tcpOptions, incoming bool) (*link, error) { +func (l *linkTCP) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error { return l.links.create( - conn, // connection - conn.RemoteAddr().String(), // connection name - proto, // connection protocol - conn.LocalAddr().String(), // local address - conn.RemoteAddr().String(), // remote address - incoming, // not incoming - false, // not forced - options.linkOptions, // connection options + conn, // connection + name, // connection name + info, // connection info + incoming, // not incoming + false, // not forced + options.linkOptions, // connection options ) } diff --git a/src/core/link_tls.go b/src/core/link_tls.go index 545d55a5..e85a336b 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -45,15 +45,19 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { return lt } -func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) { +func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) error { + info := linkInfoFor("tls", url.Host, sintf) + if l.links.isConnectedTo(info) { + return fmt.Errorf("duplicate connection attempt") + } addr, err := net.ResolveTCPAddr("tcp", url.Host) if err != nil { - return nil, err + return err } addr.Zone = sintf dialer, err := l.tcp.dialerFor(addr.String(), sintf) if err != nil { - return nil, err + return err } tlsdialer := &tls.Dialer{ NetDialer: dialer, @@ -61,12 +65,9 @@ func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, e } conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) if err != nil { - return nil, err + return err } - if _, err = l.handler(conn, options, false); err != nil { - l.core.log.Errorln("Failed to create outbound link:", err) - } - return nil, err + return l.handler(url.String(), info, conn, options, false) } func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { @@ -97,7 +98,9 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { cancel() return } - if _, err := l.handler(conn, tcpOptions{}, true); err != nil { + name := fmt.Sprintf("tls://%s", conn.RemoteAddr()) + info := linkInfoFor("tls", sintf, conn.RemoteAddr().String()) + if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil { l.core.log.Errorln("Failed to create inbound link:", err) } } @@ -155,6 +158,6 @@ func (l *linkTLS) generateConfig() (*tls.Config, error) { }, nil } -func (l *linkTLS) handler(conn net.Conn, options tcpOptions, incoming bool) (*link, error) { - return l.tcp.handler("TLS", conn, options, incoming) +func (l *linkTLS) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error { + return l.tcp.handler(name, info, conn, options, incoming) }