Refactoring

This commit is contained in:
Neil Alexander 2023-12-08 19:16:27 +00:00
parent ee4c89a84c
commit 3a9cdfd9fd
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 93 additions and 46 deletions

View file

@ -29,12 +29,19 @@ type YggdrasilTransport struct {
transport *quic.Transport transport *quic.Transport
tlsConfig *tls.Config tlsConfig *tls.Config
quicConfig *quic.Config quicConfig *quic.Config
incoming chan *yggdrasilSession incoming chan *yggdrasilStream
sessions sync.Map // string -> quic.Connection connections sync.Map // string -> *yggdrasilConnection
dials sync.Map // string -> *yggdrasilDial dials sync.Map // string -> *yggdrasilDial
} }
type yggdrasilSession struct { type yggdrasilConnection struct {
context.Context
context.CancelFunc
quic.Connection
done chan struct{}
}
type yggdrasilStream struct {
quic.Connection quic.Connection
quic.Stream quic.Stream
} }
@ -44,22 +51,25 @@ type yggdrasilDial struct {
context.CancelFunc context.CancelFunc
} }
func New(ygg *core.Core, cert tls.Certificate) (*YggdrasilTransport, error) { func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTransport, error) {
if qc == nil {
qc = &quic.Config{
HandshakeIdleTimeout: time.Second * 5,
MaxIdleTimeout: time.Second * 60,
}
}
tr := &YggdrasilTransport{ tr := &YggdrasilTransport{
tlsConfig: &tls.Config{ tlsConfig: &tls.Config{
ServerName: hex.EncodeToString(ygg.PublicKey()), ServerName: hex.EncodeToString(ygg.PublicKey()),
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true, InsecureSkipVerify: true,
}, },
quicConfig: &quic.Config{ quicConfig: qc,
HandshakeIdleTimeout: time.Second * 5,
MaxIdleTimeout: time.Second * 60,
},
transport: &quic.Transport{ transport: &quic.Transport{
Conn: ygg, Conn: ygg,
}, },
yggdrasil: ygg, yggdrasil: ygg,
incoming: make(chan *yggdrasilSession, 1), incoming: make(chan *yggdrasilStream),
} }
var err error var err error
@ -67,44 +77,58 @@ func New(ygg *core.Core, cert tls.Certificate) (*YggdrasilTransport, error) {
return nil, fmt.Errorf("quic.Listen: %w", err) return nil, fmt.Errorf("quic.Listen: %w", err)
} }
go tr.connectionAcceptLoop() go tr.connectionAcceptLoop(context.TODO())
return tr, nil return tr, nil
} }
func (t *YggdrasilTransport) connectionAcceptLoop() { func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) {
for { for {
qc, err := t.listener.Accept(context.TODO()) qc, err := t.listener.Accept(ctx)
if err != nil { if err != nil {
return return
} }
// If there's already an open connection for this node then we
// will want to shut down the existing one and replace it with
// this one.
host := qc.RemoteAddr().String() host := qc.RemoteAddr().String()
if eqc, ok := t.sessions.LoadAndDelete(host); ok { ctx, cancel := context.WithCancel(ctx)
eqc := eqc.(quic.Connection) yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
_ = eqc.CloseWithError(0, "Connection replaced") if eqc, ok := t.connections.Swap(host, yc); ok {
if eqc, ok := eqc.(*yggdrasilConnection); ok {
eqc.CancelFunc()
} }
t.sessions.Store(host, qc) }
// Now if there are any in-progress dials, we can cancel those
// too as we now have an open connection that we can open new
// streams on.
if dial, ok := t.dials.LoadAndDelete(host); ok { if dial, ok := t.dials.LoadAndDelete(host); ok {
dial := dial.(*yggdrasilDial) dial := dial.(*yggdrasilDial)
dial.CancelFunc() dial.CancelFunc()
} }
go t.streamAcceptLoop(qc) go t.streamAcceptLoop(yc)
} }
} }
func (t *YggdrasilTransport) streamAcceptLoop(qc quic.Connection) { func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) {
host := qc.RemoteAddr().String() host := yc.RemoteAddr().String()
defer qc.CloseWithError(0, "Timed out") // nolint:errcheck defer yc.CloseWithError(0, "Timed out") // nolint:errcheck
defer t.sessions.Delete(host) defer t.connections.Delete(host)
for { for {
qs, err := qc.AcceptStream(context.Background()) qs, err := yc.AcceptStream(yc.Context)
if err != nil { if err != nil {
break yc.CancelFunc()
return
}
select {
case t.incoming <- &yggdrasilStream{yc.Connection, qs}:
case <-yc.Context.Done():
return
} }
t.incoming <- &yggdrasilSession{qc, qs}
} }
} }
@ -118,20 +142,32 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
} }
ctx, cancel := context.WithTimeout(ctx, time.Second*5) ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel() defer cancel()
var retry bool
// We might want to retrying once if part of the dial process fails,
// but keep a track of whether we're already retrying.
var retrying bool
retry: retry:
qc, ok := t.sessions.Load(host) yc, ok := t.connections.Load(host)
if !ok { if !ok {
// Check if there is already a dial to this host in progress.
// If there is then we will wait for it.
if dial, ok := t.dials.Load(host); ok { if dial, ok := t.dials.Load(host); ok {
<-dial.(*yggdrasilDial).Done() <-dial.(*yggdrasilDial).Done()
} }
if qc, ok = t.sessions.Load(host); !ok {
// Even after a dial, there's no connection. This means we
// probably failed to dial, so let's try it again.
if yc, ok = t.connections.Load(host); !ok {
// A cancellable context means we can cancel the dial in
// progress from elsewhere if we need to.
dialctx, dialcancel := context.WithCancel(ctx) dialctx, dialcancel := context.WithCancel(ctx)
defer dialcancel() defer dialcancel()
// Make a record of the dial context.
t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel}) t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel})
defer t.dials.Delete(host) defer t.dials.Delete(host)
// Decode the address from hex.
addr := make(iwt.Addr, ed25519.PublicKeySize) addr := make(iwt.Addr, ed25519.PublicKeySize)
k, err := hex.DecodeString(host) k, err := hex.DecodeString(host)
if err != nil { if err != nil {
@ -139,29 +175,40 @@ retry:
} }
copy(addr, k) copy(addr, k)
// Attempt to open a QUIC session.
var qc quic.Connection
if qc, err = t.transport.Dial(dialctx, addr, t.tlsConfig, t.quicConfig); err != nil { if qc, err = t.transport.Dial(dialctx, addr, t.tlsConfig, t.quicConfig); err != nil {
return nil, err return nil, err
} }
qc := qc.(quic.Connection) // If we succeeded then we'll store our QUIC connection so
t.sessions.Store(host, qc) // that the next dial can open a stream on it directly. Start
go t.streamAcceptLoop(qc) // the accept loop so that streams can be accepted.
{
ctx, cancel := context.WithCancel(context.Background())
yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
t.connections.Store(host, yc)
go t.streamAcceptLoop(yc.(*yggdrasilConnection))
} }
} }
if qc == nil { }
return nil, net.ErrClosed if yc, ok := yc.(*yggdrasilConnection); ok {
} else { // We've either found a session or we successfully
qc := qc.(quic.Connection) // dialed a new one, so open a stream on it.
qs, err := qc.OpenStreamSync(ctx) qs, err := yc.OpenStreamSync(ctx)
if err != nil { if err != nil {
if !retry { // We failed to open a stream, so if this isn't a
retry = true // retry, then let's try opening a new connection.
if !retrying {
retrying = true
goto retry goto retry
} }
return nil, err return nil, err
} }
return &yggdrasilSession{qc, qs}, err return &yggdrasilStream{yc.Connection, qs}, err
} }
// We failed to open a session.
return nil, net.ErrClosed
} }
func (t *YggdrasilTransport) Accept() (net.Conn, error) { func (t *YggdrasilTransport) Accept() (net.Conn, error) {

View file

@ -38,11 +38,11 @@ func TestQUICOverYggdrasil(t *testing.T) {
go node2.HandleConn(node1.PublicKey(), r, 0) // nolint:errcheck go node2.HandleConn(node1.PublicKey(), r, 0) // nolint:errcheck
// Create QUIC over Yggdrasil endpoints. // Create QUIC over Yggdrasil endpoints.
quic1, err := New(node1, *cfg1.Certificate) quic1, err := New(node1, *cfg1.Certificate, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
quic2, err := New(node2, *cfg2.Certificate) quic2, err := New(node2, *cfg2.Certificate, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }