More refactoring

This commit is contained in:
Neil Alexander 2022-09-04 12:07:20 +01:00
parent 496eed7974
commit 15ce5ff319
3 changed files with 116 additions and 104 deletions

View file

@ -3,7 +3,6 @@ package core
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/hex"
"errors"
"fmt"
@ -29,14 +28,12 @@ type links struct {
tcp *linkTCP // TCP interface support
tls *linkTLS // TLS interface support
mutex sync.RWMutex // protects links below
links map[linkInfo]*link
stopped chan struct{}
links map[linkInfo]*link // *link is nil if connection in progress
// TODO timeout (to remove from switch), read from config.ReadTimeout
}
// linkInfo is used as a map key
type linkInfo struct {
key keyArray
linkType string // Type of link, e.g. TCP, AWDL
local string // Local name or address
remote string // Remote name or address
@ -50,7 +47,6 @@ type link struct {
info linkInfo
incoming bool
force bool
closed chan struct{}
}
type linkOptions struct {
@ -70,7 +66,6 @@ func (l *links) init(c *Core) error {
l.mutex.Lock()
l.links = make(map[linkInfo]*link)
l.mutex.Unlock()
l.stopped = make(chan struct{})
var listeners []ListenAddress
phony.Block(c, func() {
@ -83,8 +78,23 @@ func (l *links) init(c *Core) error {
return nil
}
func (l *links) isConnectedTo(info linkInfo) bool {
l.mutex.RLock()
defer l.mutex.RUnlock()
fmt.Println(l.links)
_, isConnected := l.links[info]
return isConnected
}
func (l *links) call(u *url.URL, sintf string) error {
// TODO: don't dial duplicates here
info := linkInfo{
linkType: strings.ToLower(u.Scheme),
local: sintf,
remote: u.Host,
}
if l.isConnectedTo(info) {
return fmt.Errorf("already connected to this node")
}
tcpOpts := tcpOptions{
linkOptions: linkOptions{
pinnedEd25519Keys: map[keyArray]struct{}{},
@ -97,10 +107,10 @@ func (l *links) call(u *url.URL, sintf string) error {
tcpOpts.pinnedEd25519Keys[sigPubKey] = struct{}{}
}
}
switch u.Scheme {
switch info.linkType {
case "tcp":
go func() {
if _, err := l.tcp.dial(u, tcpOpts, sintf); err != nil {
if err := l.tcp.dial(u, tcpOpts, sintf); err != nil {
l.core.log.Warnf("Failed to dial TCP %s: %s\n", u.Host, err)
}
}()
@ -136,7 +146,7 @@ func (l *links) call(u *url.URL, sintf string) error {
}
}
go func() {
if _, err := l.tls.dial(u, tcpOpts, sintf); err != nil {
if err := l.tls.dial(u, tcpOpts, sintf); err != nil {
l.core.log.Warnf("Failed to dial TLS %s: %s\n", u.Host, err)
}
}()
@ -147,19 +157,7 @@ func (l *links) call(u *url.URL, sintf string) error {
return nil
}
func (l *links) create(conn net.Conn, name, linkType, local, remote string, incoming, force bool, options linkOptions) (*link, error) {
// Technically anything unique would work for names, but let's pick something human readable, just for debugging
info := linkInfo{
linkType: linkType,
local: local,
remote: remote,
}
l.mutex.RLock()
_, isIn := l.links[info]
l.mutex.RUnlock()
if isIn {
return nil, fmt.Errorf("duplicate")
}
func (l *links) create(conn net.Conn, name string, info linkInfo, incoming, force bool, options linkOptions) error {
intf := link{
conn: &linkConn{
Conn: conn,
@ -173,16 +171,34 @@ func (l *links) create(conn net.Conn, name, linkType, local, remote string, inco
force: force,
}
go func() {
if _, err := intf.handler(); err != nil {
l.core.log.Warnf("Handler error (incoming %v): %s\n", incoming, err)
if err := intf.handler(info); err != nil {
l.core.log.Errorf("Link handler error (%s): %s", conn.RemoteAddr(), err)
}
}()
return &intf, nil
return nil
}
func (intf *link) handler() (chan struct{}, error) {
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
func (intf *link) handler(info linkInfo) error {
defer intf.conn.Close()
// Don't connect to this link more than once.
if intf.links.isConnectedTo(info) {
return fmt.Errorf("already connected to %+v", info)
}
// Mark the connection as in progress.
intf.links.mutex.Lock()
intf.links.links[info] = nil
intf.links.mutex.Unlock()
// When we're done, clean up the connection entry.
defer func() {
intf.links.mutex.Lock()
delete(intf.links.links, info)
intf.links.mutex.Unlock()
}()
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
meta := version_getBaseMetadata()
meta.key = intf.links.core.public
metaBytes := meta.encode()
@ -195,10 +211,10 @@ func (intf *link) handler() (chan struct{}, error) {
err = errors.New("incomplete metadata send")
}
}) {
return nil, errors.New("timeout on metadata send")
return errors.New("timeout on metadata send")
}
if err != nil {
return nil, fmt.Errorf("write handshake: %w", err)
return fmt.Errorf("write handshake: %w", err)
}
if !util.FuncTimeout(30*time.Second, func() {
var n int
@ -207,15 +223,15 @@ func (intf *link) handler() (chan struct{}, error) {
err = errors.New("incomplete metadata recv")
}
}) {
return nil, errors.New("timeout on metadata recv")
return errors.New("timeout on metadata recv")
}
if err != nil {
return nil, fmt.Errorf("read handshake: %w", err)
return fmt.Errorf("read handshake: %w", err)
}
meta = version_metadata{}
base := version_getBaseMetadata()
if !meta.decode(metaBytes) {
return nil, errors.New("failed to decode metadata")
return errors.New("failed to decode metadata")
}
if !meta.check() {
var connectError string
@ -230,7 +246,7 @@ func (intf *link) handler() (chan struct{}, error) {
fmt.Sprintf("%d.%d", base.ver, base.minorVer),
fmt.Sprintf("%d.%d", meta.ver, meta.minorVer),
)
return nil, errors.New("remote node is incompatible version")
return errors.New("remote node is incompatible version")
}
// Check if the remote side matches the keys we expected. This is a bit of a weak
// check - in future versions we really should check a signature or something like that.
@ -239,7 +255,7 @@ func (intf *link) handler() (chan struct{}, error) {
copy(key[:], meta.key)
if _, allowed := pinned[key]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name())
return nil, fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys")
return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys")
}
}
// Check if we're authorized to connect to this key / IP
@ -255,58 +271,50 @@ func (intf *link) handler() (chan struct{}, error) {
intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s",
strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.key))
intf.close()
return nil, nil
}
// Check if we already have a link to this node
copy(intf.info.key[:], meta.key)
intf.links.mutex.Lock()
if oldIntf, isIn := intf.links.links[intf.info]; isIn {
intf.links.mutex.Unlock()
// FIXME we should really return an error and let the caller block instead
// That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc.
intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name())
return oldIntf.closed, nil
} else {
intf.closed = make(chan struct{})
intf.links.links[intf.info] = intf
defer func() {
intf.links.mutex.Lock()
delete(intf.links.links, intf.info)
intf.links.mutex.Unlock()
close(intf.closed)
}()
intf.links.core.log.Debugln("DEBUG: registered interface for", intf.name())
}
intf.links.mutex.Unlock()
themAddr := address.AddrForKey(ed25519.PublicKey(intf.info.key[:]))
themAddrString := net.IP(themAddr[:]).String()
themString := fmt.Sprintf("%s@%s", themAddrString, intf.info.remote)
intf.links.core.log.Infof("Connected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local)
// Run the handler
err = intf.links.core.HandleConn(ed25519.PublicKey(intf.info.key[:]), intf.conn)
if err != nil {
err = fmt.Errorf("connection: %w", err)
}
// TODO don't report an error if it's just a 'use of closed network connection'
if err != nil {
intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local, err)
} else {
intf.links.core.log.Infof("Disconnected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local)
}
return nil, err
return fmt.Errorf("forbidden connection")
}
func (intf *link) close() {
intf.conn.Close()
intf.links.mutex.Lock()
intf.links.links[info] = intf
intf.links.mutex.Unlock()
remoteAddr := net.IP(address.AddrForKey(meta.key)[:]).String()
remoteStr := fmt.Sprintf("%s@%s", remoteAddr, intf.info.remote)
localStr := intf.conn.LocalAddr()
intf.links.core.log.Infof("Connected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), remoteStr, localStr)
// TODO don't report an error if it's just a 'use of closed network connection'
if err = intf.links.core.HandleConn(meta.key, intf.conn); err != nil {
intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s",
strings.ToUpper(intf.info.linkType), remoteStr, localStr, err)
} else {
intf.links.core.log.Infof("Disconnected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), remoteStr, localStr)
}
return nil
}
func (intf *link) close() error {
return intf.conn.Close()
}
func (intf *link) name() string {
return intf.lname
}
func linkInfoFor(linkType, local, remote string) linkInfo {
if h, _, err := net.SplitHostPort(remote); err == nil {
remote = h
}
return linkInfo{
linkType: linkType,
local: local,
remote: remote,
}
}
type linkConn struct {
// tx and rx are at the beginning of the struct to ensure 64-bit alignment
// on 32-bit platforms, see https://pkg.go.dev/sync/atomic#pkg-note-BUG

View file

@ -38,24 +38,25 @@ func (l *links) newLinkTCP() *linkTCP {
return lt
}
func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) {
func (l *linkTCP) dial(url *url.URL, options tcpOptions, sintf string) error {
info := linkInfoFor("tcp", url.Host, sintf)
if l.links.isConnectedTo(info) {
return fmt.Errorf("duplicate connection attempt")
}
addr, err := net.ResolveTCPAddr("tcp", url.Host)
if err != nil {
return nil, err
return err
}
addr.Zone = sintf
dialer, err := l.dialerFor(addr.String(), sintf)
if err != nil {
return nil, err
return err
}
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil {
return nil, err
return err
}
if _, err = l.handler("TCP", conn, options, false); err != nil {
l.core.log.Errorln("Failed to create outbound link:", err)
}
return nil, err
return l.handler(url.String(), info, conn, options, false)
}
func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
@ -85,8 +86,10 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
cancel()
return
}
if _, err := l.handler("TCP", conn, tcpOptions{}, true); err != nil {
l.core.log.Errorln("Failed to create link:", err)
name := fmt.Sprintf("tcp://%s", conn.RemoteAddr())
info := linkInfoFor("tcp", sintf, conn.RemoteAddr().String())
if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err)
}
}
}()
@ -96,13 +99,11 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
}, nil
}
func (l *linkTCP) handler(proto string, conn net.Conn, options tcpOptions, incoming bool) (*link, error) {
func (l *linkTCP) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error {
return l.links.create(
conn, // connection
conn.RemoteAddr().String(), // connection name
proto, // connection protocol
conn.LocalAddr().String(), // local address
conn.RemoteAddr().String(), // remote address
name, // connection name
info, // connection info
incoming, // not incoming
false, // not forced
options.linkOptions, // connection options

View file

@ -45,15 +45,19 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
return lt
}
func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, error) {
func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) error {
info := linkInfoFor("tls", url.Host, sintf)
if l.links.isConnectedTo(info) {
return fmt.Errorf("duplicate connection attempt")
}
addr, err := net.ResolveTCPAddr("tcp", url.Host)
if err != nil {
return nil, err
return err
}
addr.Zone = sintf
dialer, err := l.tcp.dialerFor(addr.String(), sintf)
if err != nil {
return nil, err
return err
}
tlsdialer := &tls.Dialer{
NetDialer: dialer,
@ -61,12 +65,9 @@ func (l *linkTLS) dial(url *url.URL, options tcpOptions, sintf string) (*link, e
}
conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil {
return nil, err
return err
}
if _, err = l.handler(conn, options, false); err != nil {
l.core.log.Errorln("Failed to create outbound link:", err)
}
return nil, err
return l.handler(url.String(), info, conn, options, false)
}
func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
@ -97,7 +98,9 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
cancel()
return
}
if _, err := l.handler(conn, tcpOptions{}, true); err != nil {
name := fmt.Sprintf("tls://%s", conn.RemoteAddr())
info := linkInfoFor("tls", sintf, conn.RemoteAddr().String())
if err = l.handler(name, info, conn, tcpOptions{}, true); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err)
}
}
@ -155,6 +158,6 @@ func (l *linkTLS) generateConfig() (*tls.Config, error) {
}, nil
}
func (l *linkTLS) handler(conn net.Conn, options tcpOptions, incoming bool) (*link, error) {
return l.tcp.handler("TLS", conn, options, incoming)
func (l *linkTLS) handler(name string, info linkInfo, conn net.Conn, options tcpOptions, incoming bool) error {
return l.tcp.handler(name, info, conn, options, incoming)
}