diff --git a/yggquic.go b/yggquic.go index 73b0495..fffb8b7 100644 --- a/yggquic.go +++ b/yggquic.go @@ -24,17 +24,24 @@ import ( ) type YggdrasilTransport struct { - yggdrasil net.PacketConn - listener *quic.Listener - transport *quic.Transport - tlsConfig *tls.Config - quicConfig *quic.Config - incoming chan *yggdrasilSession - sessions sync.Map // string -> quic.Connection - dials sync.Map // string -> *yggdrasilDial + yggdrasil net.PacketConn + listener *quic.Listener + transport *quic.Transport + tlsConfig *tls.Config + quicConfig *quic.Config + incoming chan *yggdrasilStream + connections sync.Map // string -> *yggdrasilConnection + dials sync.Map // string -> *yggdrasilDial } -type yggdrasilSession struct { +type yggdrasilConnection struct { + context.Context + context.CancelFunc + quic.Connection + done chan struct{} +} + +type yggdrasilStream struct { quic.Connection quic.Stream } @@ -44,22 +51,25 @@ type yggdrasilDial struct { context.CancelFunc } -func New(ygg *core.Core, cert tls.Certificate) (*YggdrasilTransport, error) { +func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTransport, error) { + if qc == nil { + qc = &quic.Config{ + HandshakeIdleTimeout: time.Second * 5, + MaxIdleTimeout: time.Second * 60, + } + } tr := &YggdrasilTransport{ tlsConfig: &tls.Config{ ServerName: hex.EncodeToString(ygg.PublicKey()), Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, }, - quicConfig: &quic.Config{ - HandshakeIdleTimeout: time.Second * 5, - MaxIdleTimeout: time.Second * 60, - }, + quicConfig: qc, transport: &quic.Transport{ Conn: ygg, }, yggdrasil: ygg, - incoming: make(chan *yggdrasilSession, 1), + incoming: make(chan *yggdrasilStream), } var err error @@ -67,44 +77,58 @@ func New(ygg *core.Core, cert tls.Certificate) (*YggdrasilTransport, error) { return nil, fmt.Errorf("quic.Listen: %w", err) } - go tr.connectionAcceptLoop() + go tr.connectionAcceptLoop(context.TODO()) return tr, nil } -func (t *YggdrasilTransport) connectionAcceptLoop() { +func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) { for { - qc, err := t.listener.Accept(context.TODO()) + qc, err := t.listener.Accept(ctx) if err != nil { return } + // If there's already an open connection for this node then we + // will want to shut down the existing one and replace it with + // this one. host := qc.RemoteAddr().String() - if eqc, ok := t.sessions.LoadAndDelete(host); ok { - eqc := eqc.(quic.Connection) - _ = eqc.CloseWithError(0, "Connection replaced") + ctx, cancel := context.WithCancel(ctx) + yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} + if eqc, ok := t.connections.Swap(host, yc); ok { + if eqc, ok := eqc.(*yggdrasilConnection); ok { + eqc.CancelFunc() + } } - t.sessions.Store(host, qc) + + // Now if there are any in-progress dials, we can cancel those + // too as we now have an open connection that we can open new + // streams on. if dial, ok := t.dials.LoadAndDelete(host); ok { dial := dial.(*yggdrasilDial) dial.CancelFunc() } - go t.streamAcceptLoop(qc) + go t.streamAcceptLoop(yc) } } -func (t *YggdrasilTransport) streamAcceptLoop(qc quic.Connection) { - host := qc.RemoteAddr().String() +func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) { + host := yc.RemoteAddr().String() - defer qc.CloseWithError(0, "Timed out") // nolint:errcheck - defer t.sessions.Delete(host) + defer yc.CloseWithError(0, "Timed out") // nolint:errcheck + defer t.connections.Delete(host) for { - qs, err := qc.AcceptStream(context.Background()) + qs, err := yc.AcceptStream(yc.Context) if err != nil { - break + yc.CancelFunc() + return + } + select { + case t.incoming <- &yggdrasilStream{yc.Connection, qs}: + case <-yc.Context.Done(): + return } - t.incoming <- &yggdrasilSession{qc, qs} } } @@ -118,20 +142,32 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - var retry bool + + // We might want to retrying once if part of the dial process fails, + // but keep a track of whether we're already retrying. + var retrying bool retry: - qc, ok := t.sessions.Load(host) + yc, ok := t.connections.Load(host) if !ok { + // Check if there is already a dial to this host in progress. + // If there is then we will wait for it. if dial, ok := t.dials.Load(host); ok { <-dial.(*yggdrasilDial).Done() } - if qc, ok = t.sessions.Load(host); !ok { + + // Even after a dial, there's no connection. This means we + // probably failed to dial, so let's try it again. + if yc, ok = t.connections.Load(host); !ok { + // A cancellable context means we can cancel the dial in + // progress from elsewhere if we need to. dialctx, dialcancel := context.WithCancel(ctx) defer dialcancel() + // Make a record of the dial context. t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel}) defer t.dials.Delete(host) + // Decode the address from hex. addr := make(iwt.Addr, ed25519.PublicKeySize) k, err := hex.DecodeString(host) if err != nil { @@ -139,29 +175,40 @@ retry: } copy(addr, k) + // Attempt to open a QUIC session. + var qc quic.Connection if qc, err = t.transport.Dial(dialctx, addr, t.tlsConfig, t.quicConfig); err != nil { return nil, err } - qc := qc.(quic.Connection) - t.sessions.Store(host, qc) - go t.streamAcceptLoop(qc) + // If we succeeded then we'll store our QUIC connection so + // that the next dial can open a stream on it directly. Start + // the accept loop so that streams can be accepted. + { + ctx, cancel := context.WithCancel(context.Background()) + yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} + t.connections.Store(host, yc) + go t.streamAcceptLoop(yc.(*yggdrasilConnection)) + } } } - if qc == nil { - return nil, net.ErrClosed - } else { - qc := qc.(quic.Connection) - qs, err := qc.OpenStreamSync(ctx) + if yc, ok := yc.(*yggdrasilConnection); ok { + // We've either found a session or we successfully + // dialed a new one, so open a stream on it. + qs, err := yc.OpenStreamSync(ctx) if err != nil { - if !retry { - retry = true + // We failed to open a stream, so if this isn't a + // retry, then let's try opening a new connection. + if !retrying { + retrying = true goto retry } return nil, err } - return &yggdrasilSession{qc, qs}, err + return &yggdrasilStream{yc.Connection, qs}, err } + // We failed to open a session. + return nil, net.ErrClosed } func (t *YggdrasilTransport) Accept() (net.Conn, error) { diff --git a/yggquic_test.go b/yggquic_test.go index 7f2e5a3..e6c4c65 100644 --- a/yggquic_test.go +++ b/yggquic_test.go @@ -38,11 +38,11 @@ func TestQUICOverYggdrasil(t *testing.T) { go node2.HandleConn(node1.PublicKey(), r, 0) // nolint:errcheck // Create QUIC over Yggdrasil endpoints. - quic1, err := New(node1, *cfg1.Certificate) + quic1, err := New(node1, *cfg1.Certificate, nil) if err != nil { t.Fatal(err) } - quic2, err := New(node2, *cfg2.Certificate) + quic2, err := New(node2, *cfg2.Certificate, nil) if err != nil { t.Fatal(err) }