More refactoring

This commit is contained in:
Neil Alexander 2022-09-04 12:07:20 +01:00
parent 496eed7974
commit 15ce5ff319
3 changed files with 116 additions and 104 deletions

View file

@ -3,7 +3,6 @@ package core
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ed25519"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@ -25,18 +24,16 @@ import (
) )
type links struct { type links struct {
core *Core core *Core
tcp *linkTCP // TCP interface support tcp *linkTCP // TCP interface support
tls *linkTLS // TLS interface support tls *linkTLS // TLS interface support
mutex sync.RWMutex // protects links below mutex sync.RWMutex // protects links below
links map[linkInfo]*link links map[linkInfo]*link // *link is nil if connection in progress
stopped chan struct{}
// TODO timeout (to remove from switch), read from config.ReadTimeout // TODO timeout (to remove from switch), read from config.ReadTimeout
} }
// linkInfo is used as a map key // linkInfo is used as a map key
type linkInfo struct { type linkInfo struct {
key keyArray
linkType string // Type of link, e.g. TCP, AWDL linkType string // Type of link, e.g. TCP, AWDL
local string // Local name or address local string // Local name or address
remote string // Remote name or address remote string // Remote name or address
@ -50,7 +47,6 @@ type link struct {
info linkInfo info linkInfo
incoming bool incoming bool
force bool force bool
closed chan struct{}
} }
type linkOptions struct { type linkOptions struct {
@ -70,7 +66,6 @@ func (l *links) init(c *Core) error {
l.mutex.Lock() l.mutex.Lock()
l.links = make(map[linkInfo]*link) l.links = make(map[linkInfo]*link)
l.mutex.Unlock() l.mutex.Unlock()
l.stopped = make(chan struct{})
var listeners []ListenAddress var listeners []ListenAddress
phony.Block(c, func() { phony.Block(c, func() {
@ -83,8 +78,23 @@ func (l *links) init(c *Core) error {
return nil return nil
} }
func (l *links) isConnectedTo(info linkInfo) bool {
l.mutex.RLock()
defer l.mutex.RUnlock()
fmt.Println(l.links)
_, isConnected := l.links[info]
return isConnected
}
func (l *links) call(u *url.URL, sintf string) error { func (l *links) call(u *url.URL, sintf string) error {
// TODO: don't dial duplicates here info := linkInfo{
linkType: strings.ToLower(u.Scheme),
local: sintf,
remote: u.Host,
}
if l.isConnectedTo(info) {
return fmt.Errorf("already connected to this node")
}
tcpOpts := tcpOptions{ tcpOpts := tcpOptions{
linkOptions: linkOptions{ linkOptions: linkOptions{
pinnedEd25519Keys: map[keyArray]struct{}{}, pinnedEd25519Keys: map[keyArray]struct{}{},
@ -97,10 +107,10 @@ func (l *links) call(u *url.URL, sintf string) error {
tcpOpts.pinnedEd25519Keys[sigPubKey] = struct{}{} tcpOpts.pinnedEd25519Keys[sigPubKey] = struct{}{}
} }
} }
switch u.Scheme { switch info.linkType {
case "tcp": case "tcp":
go func() { go func() {
if _, err := l.tcp.dial(u, tcpOpts, sintf); err != nil { if err := l.tcp.dial(u, tcpOpts, sintf); err != nil {
l.core.log.Warnf("Failed to dial TCP %s: %s\n", u.Host, err) l.core.log.Warnf("Failed to dial TCP %s: %s\n", u.Host, err)
} }
}() }()
@ -136,7 +146,7 @@ func (l *links) call(u *url.URL, sintf string) error {
} }
} }
go func() { go func() {
if _, err := l.tls.dial(u, tcpOpts, sintf); err != nil { if err := l.tls.dial(u, tcpOpts, sintf); err != nil {
l.core.log.Warnf("Failed to dial TLS %s: %s\n", u.Host, err) l.core.log.Warnf("Failed to dial TLS %s: %s\n", u.Host, err)
} }
}() }()
@ -147,19 +157,7 @@ func (l *links) call(u *url.URL, sintf string) error {
return nil return nil
} }
func (l *links) create(conn net.Conn, name, linkType, local, remote string, incoming, force bool, options linkOptions) (*link, error) { func (l *links) create(conn net.Conn, name string, info linkInfo, incoming, force bool, options linkOptions) error {
// Technically anything unique would work for names, but let's pick something human readable, just for debugging
info := linkInfo{
linkType: linkType,
local: local,
remote: remote,
}
l.mutex.RLock()
_, isIn := l.links[info]
l.mutex.RUnlock()
if isIn {
return nil, fmt.Errorf("duplicate")
}
intf := link{ intf := link{
conn: &linkConn{ conn: &linkConn{
Conn: conn, Conn: conn,
@ -173,16 +171,34 @@ func (l *links) create(conn net.Conn, name, linkType, local, remote string, inco
force: force, force: force,
} }
go func() { go func() {
if _, err := intf.handler(); err != nil { if err := intf.handler(info); err != nil {
l.core.log.Warnf("Handler error (incoming %v): %s\n", incoming, err) l.core.log.Errorf("Link handler error (%s): %s", conn.RemoteAddr(), err)
} }
}() }()
return &intf, nil return nil
} }
func (intf *link) handler() (chan struct{}, error) { func (intf *link) handler(info linkInfo) error {
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
defer intf.conn.Close() defer intf.conn.Close()
// Don't connect to this link more than once.
if intf.links.isConnectedTo(info) {
return fmt.Errorf("already connected to %+v", info)
}
// Mark the connection as in progress.
intf.links.mutex.Lock()
intf.links.links[info] = nil
intf.links.mutex.Unlock()
// When we're done, clean up the connection entry.
defer func() {
intf.links.mutex.Lock()
delete(intf.links.links, info)
intf.links.mutex.Unlock()
}()
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
meta := version_getBaseMetadata() meta := version_getBaseMetadata()
meta.key = intf.links.core.public meta.key = intf.links.core.public
metaBytes := meta.encode() metaBytes := meta.encode()
@ -195,10 +211,10 @@ func (intf *link) handler() (chan struct{}, error) {
err = errors.New("incomplete metadata send") err = errors.New("incomplete metadata send")
} }
}) { }) {
return nil, errors.New("timeout on metadata send") return errors.New("timeout on metadata send")
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("write handshake: %w", err) return fmt.Errorf("write handshake: %w", err)
} }
if !util.FuncTimeout(30*time.Second, func() { if !util.FuncTimeout(30*time.Second, func() {
var n int var n int
@ -207,15 +223,15 @@ func (intf *link) handler() (chan struct{}, error) {
err = errors.New("incomplete metadata recv") err = errors.New("incomplete metadata recv")
} }
}) { }) {
return nil, errors.New("timeout on metadata recv") return errors.New("timeout on metadata recv")
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("read handshake: %w", err) return fmt.Errorf("read handshake: %w", err)
} }
meta = version_metadata{} meta = version_metadata{}
base := version_getBaseMetadata() base := version_getBaseMetadata()
if !meta.decode(metaBytes) { if !meta.decode(metaBytes) {
return nil, errors.New("failed to decode metadata") return errors.New("failed to decode metadata")
} }
if !meta.check() { if !meta.check() {
var connectError string var connectError string
@ -230,7 +246,7 @@ func (intf *link) handler() (chan struct{}, error) {
fmt.Sprintf("%d.%d", base.ver, base.minorVer), fmt.Sprintf("%d.%d", base.ver, base.minorVer),
fmt.Sprintf("%d.%d", meta.ver, meta.minorVer), fmt.Sprintf("%d.%d", meta.ver, meta.minorVer),
) )
return nil, errors.New("remote node is incompatible version") return errors.New("remote node is incompatible version")
} }
// Check if the remote side matches the keys we expected. This is a bit of a weak // Check if the remote side matches the keys we expected. This is a bit of a weak
// check - in future versions we really should check a signature or something like that. // check - in future versions we really should check a signature or something like that.
@ -239,7 +255,7 @@ func (intf *link) handler() (chan struct{}, error) {
copy(key[:], meta.key) copy(key[:], meta.key)
if _, allowed := pinned[key]; !allowed { if _, allowed := pinned[key]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name()) intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name())
return nil, fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys")
} }
} }
// Check if we're authorized to connect to this key / IP // Check if we're authorized to connect to this key / IP
@ -255,58 +271,50 @@ func (intf *link) handler() (chan struct{}, error) {
intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s", intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s",
strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.key)) strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.key))
intf.close() intf.close()
return nil, nil return fmt.Errorf("forbidden connection")
} }
// Check if we already have a link to this node
copy(intf.info.key[:], meta.key)
intf.links.mutex.Lock() intf.links.mutex.Lock()
if oldIntf, isIn := intf.links.links[intf.info]; isIn { intf.links.links[info] = intf
intf.links.mutex.Unlock()
// FIXME we should really return an error and let the caller block instead
// That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc.
intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name())
return oldIntf.closed, nil
} else {
intf.closed = make(chan struct{})
intf.links.links[intf.info] = intf
defer func() {
intf.links.mutex.Lock()
delete(intf.links.links, intf.info)
intf.links.mutex.Unlock()
close(intf.closed)
}()
intf.links.core.log.Debugln("DEBUG: registered interface for", intf.name())
}
intf.links.mutex.Unlock() intf.links.mutex.Unlock()
themAddr := address.AddrForKey(ed25519.PublicKey(intf.info.key[:]))
themAddrString := net.IP(themAddr[:]).String() remoteAddr := net.IP(address.AddrForKey(meta.key)[:]).String()
themString := fmt.Sprintf("%s@%s", themAddrString, intf.info.remote) remoteStr := fmt.Sprintf("%s@%s", remoteAddr, intf.info.remote)
localStr := intf.conn.LocalAddr()
intf.links.core.log.Infof("Connected %s: %s, source %s", intf.links.core.log.Infof("Connected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local) strings.ToUpper(intf.info.linkType), remoteStr, localStr)
// Run the handler
err = intf.links.core.HandleConn(ed25519.PublicKey(intf.info.key[:]), intf.conn)
if err != nil {
err = fmt.Errorf("connection: %w", err)
}
// TODO don't report an error if it's just a 'use of closed network connection' // TODO don't report an error if it's just a 'use of closed network connection'
if err != nil { if err = intf.links.core.HandleConn(meta.key, intf.conn); err != nil {
intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s", intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local, err) strings.ToUpper(intf.info.linkType), remoteStr, localStr, err)
} else { } else {
intf.links.core.log.Infof("Disconnected %s: %s, source %s", intf.links.core.log.Infof("Disconnected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local) strings.ToUpper(intf.info.linkType), remoteStr, localStr)
} }
return nil, err
return nil
} }
func (intf *link) close() { func (intf *link) close() error {
intf.conn.Close() return intf.conn.Close()
} }
func (intf *link) name() string { func (intf *link) name() string {
return intf.lname return intf.lname
} }
func linkInfoFor(linkType, local, remote string) linkInfo {
if h, _, err := net.SplitHostPort(remote); err == nil {
remote = h
}
return linkInfo{
linkType: linkType,
local: local,
remote: remote,
}
}
type linkConn struct { type linkConn struct {
// tx and rx are at the beginning of the struct to ensure 64-bit alignment // tx and rx are at the beginning of the struct to ensure 64-bit alignment
// on 32-bit platforms, see https://pkg.go.dev/sync/atomic#pkg-note-BUG // on 32-bit platforms, see https://pkg.go.dev/sync/atomic#pkg-note-BUG

View file

@ -38,24 +38,25 @@ func (l *links) newLinkTCP() *linkTCP {
return lt return lt
} }
func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) { func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) error {
info := linkInfoFor("tcp", url.Host, sintf)
if l.links.isConnectedTo(info) {
return fmt.Errorf("duplicate connection attempt")
}
addr, err := net.ResolveTCPAddr("tcp", url.Host) addr, err := net.ResolveTCPAddr("tcp", url.Host)
if err != nil { if err != nil {
return nil, err return err
} }
addr.Zone = sintf addr.Zone = sintf
dialer, err := l.dialerFor(addr.String(), sintf) dialer, err := l.dialerFor(addr.String(), sintf)
if err != nil { if err != nil {
return nil, err return err
} }
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil { if err != nil {
return nil, err return err
} }
if _, err = l.handler("TCP", conn, options, false); err != nil { return l.handler(url.String(), info, conn, options, false)
l.core.log.Errorln("Failed to create outbound link:", err)
}
return nil, err
} }
func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
@ -85,8 +86,10 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
return return
} }
if _, err := l.handler("TCP", conn, tcpOptions{}, true); err != nil { name := fmt.Sprintf("tcp://%s", conn.RemoteAddr())
l.core.log.Errorln("Failed to create link:", err) info := linkInfoFor("tcp", sintf, conn.RemoteAddr().String())
if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err)
} }
} }
}() }()
@ -96,16 +99,14 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
}, nil }, nil
} }
func (l *linkTCP) handler(proto string, conn net.Conn, options tcpOptions, incoming bool) (*link, error) { func (l *linkTCP) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error {
return l.links.create( return l.links.create(
conn, // connection conn, // connection
conn.RemoteAddr().String(), // connection name name, // connection name
proto, // connection protocol info, // connection info
conn.LocalAddr().String(), // local address incoming, // not incoming
conn.RemoteAddr().String(), // remote address false, // not forced
incoming, // not incoming options.linkOptions, // connection options
false, // not forced
options.linkOptions, // connection options
) )
} }

View file

@ -45,15 +45,19 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
return lt return lt
} }
func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) { func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) error {
info := linkInfoFor("tls", url.Host, sintf)
if l.links.isConnectedTo(info) {
return fmt.Errorf("duplicate connection attempt")
}
addr, err := net.ResolveTCPAddr("tcp", url.Host) addr, err := net.ResolveTCPAddr("tcp", url.Host)
if err != nil { if err != nil {
return nil, err return err
} }
addr.Zone = sintf addr.Zone = sintf
dialer, err := l.tcp.dialerFor(addr.String(), sintf) dialer, err := l.tcp.dialerFor(addr.String(), sintf)
if err != nil { if err != nil {
return nil, err return err
} }
tlsdialer := &tls.Dialer{ tlsdialer := &tls.Dialer{
NetDialer: dialer, NetDialer: dialer,
@ -61,12 +65,9 @@ func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, e
} }
conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil { if err != nil {
return nil, err return err
} }
if _, err = l.handler(conn, options, false); err != nil { return l.handler(url.String(), info, conn, options, false)
l.core.log.Errorln("Failed to create outbound link:", err)
}
return nil, err
} }
func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
@ -97,7 +98,9 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
return return
} }
if _, err := l.handler(conn, tcpOptions{}, true); err != nil { name := fmt.Sprintf("tls://%s", conn.RemoteAddr())
info := linkInfoFor("tls", sintf, conn.RemoteAddr().String())
if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err) l.core.log.Errorln("Failed to create inbound link:", err)
} }
} }
@ -155,6 +158,6 @@ func (l *linkTLS) generateConfig() (*tls.Config, error) {
}, nil }, nil
} }
func (l *linkTLS) handler(conn net.Conn, options tcpOptions, incoming bool) (*link, error) { func (l *linkTLS) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error {
return l.tcp.handler("TLS", conn, options, incoming) return l.tcp.handler(name, info, conn, options, incoming)
} }