From b412fc6f0d7e73f622d46693e81aee718f528762 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Sat, 9 Dec 2023 22:01:36 +0000 Subject: [PATCH] Tweaks --- yggquic.go | 50 ++++++++++++++++++++++++------------------------- yggquic_test.go | 42 ++++++++++++++++++++++------------------- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/yggquic.go b/yggquic.go index fffb8b7..d9517de 100644 --- a/yggquic.go +++ b/yggquic.go @@ -24,6 +24,8 @@ import ( ) type YggdrasilTransport struct { + ctx context.Context + cancel context.CancelFunc yggdrasil net.PacketConn listener *quic.Listener transport *quic.Transport @@ -38,11 +40,10 @@ type yggdrasilConnection struct { context.Context context.CancelFunc quic.Connection - done chan struct{} } type yggdrasilStream struct { - quic.Connection + *yggdrasilConnection quic.Stream } @@ -71,13 +72,14 @@ func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTrans yggdrasil: ygg, incoming: make(chan *yggdrasilStream), } + tr.ctx, tr.cancel = context.WithCancel(context.Background()) var err error if tr.listener, err = tr.transport.Listen(tr.tlsConfig, tr.quicConfig); err != nil { return nil, fmt.Errorf("quic.Listen: %w", err) } - go tr.connectionAcceptLoop(context.TODO()) + go tr.connectionAcceptLoop(tr.ctx) return tr, nil } @@ -92,41 +94,41 @@ func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) { // will want to shut down the existing one and replace it with // this one. host := qc.RemoteAddr().String() - ctx, cancel := context.WithCancel(ctx) - yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} + ctx, cancel := context.WithCancel(t.ctx) + yc := &yggdrasilConnection{ctx, cancel, qc} if eqc, ok := t.connections.Swap(host, yc); ok { if eqc, ok := eqc.(*yggdrasilConnection); ok { eqc.CancelFunc() } } + go t.streamAcceptLoop(yc) + // Now if there are any in-progress dials, we can cancel those // too as we now have an open connection that we can open new // streams on. if dial, ok := t.dials.LoadAndDelete(host); ok { - dial := dial.(*yggdrasilDial) - dial.CancelFunc() + dial.(*yggdrasilDial).CancelFunc() } - - go t.streamAcceptLoop(yc) } } func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) { host := yc.RemoteAddr().String() - defer yc.CloseWithError(0, "Timed out") // nolint:errcheck defer t.connections.Delete(host) for { qs, err := yc.AcceptStream(yc.Context) if err != nil { - yc.CancelFunc() return } select { - case t.incoming <- &yggdrasilStream{yc.Connection, qs}: + case t.incoming <- &yggdrasilStream{yc, qs}: + // An Accept call is waiting. case <-yc.Context.Done(): + // We've timed out waiting for a call to Accept + // to handle the connection. return } } @@ -140,8 +142,12 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri if network != "yggdrasil" { return nil, fmt.Errorf("network must be 'yggdrasil'") } - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() + + // Check if there is already a dial to this host in progress. + // If there is then we will wait for it. + if dial, ok := t.dials.Load(host); ok { + <-dial.(*yggdrasilDial).Done() + } // We might want to retrying once if part of the dial process fails, // but keep a track of whether we're already retrying. @@ -149,22 +155,14 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri retry: yc, ok := t.connections.Load(host) if !ok { - // Check if there is already a dial to this host in progress. - // If there is then we will wait for it. - if dial, ok := t.dials.Load(host); ok { - <-dial.(*yggdrasilDial).Done() - } - // Even after a dial, there's no connection. This means we // probably failed to dial, so let's try it again. if yc, ok = t.connections.Load(host); !ok { // A cancellable context means we can cancel the dial in // progress from elsewhere if we need to. - dialctx, dialcancel := context.WithCancel(ctx) - defer dialcancel() - - // Make a record of the dial context. + dialctx, dialcancel := context.WithTimeout(ctx, time.Second*5) t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel}) + defer dialcancel() defer t.dials.Delete(host) // Decode the address from hex. @@ -186,7 +184,7 @@ retry: // the accept loop so that streams can be accepted. { ctx, cancel := context.WithCancel(context.Background()) - yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} + yc = &yggdrasilConnection{ctx, cancel, qc} t.connections.Store(host, yc) go t.streamAcceptLoop(yc.(*yggdrasilConnection)) } @@ -205,7 +203,7 @@ retry: } return nil, err } - return &yggdrasilStream{yc.Connection, qs}, err + return &yggdrasilStream{yc, qs}, err } // We failed to open a session. return nil, net.ErrClosed diff --git a/yggquic_test.go b/yggquic_test.go index e6c4c65..f813876 100644 --- a/yggquic_test.go +++ b/yggquic_test.go @@ -52,32 +52,36 @@ func TestQUICOverYggdrasil(t *testing.T) { t.Parallel() destination := hex.EncodeToString(node1.PublicKey()) - c, err := quic2.Dial("yggdrasil", destination) - if err != nil { - t.Fatal(err) - } - t.Logf("Opened connection to %q", c.RemoteAddr().String()) - if _, err = c.Write([]byte("Hello!")); err != nil { - t.Fatal(err) - } - if err = c.Close(); err != nil { - t.Fatal(err) + for i := 0; i < 5; i++ { + c, err := quic2.Dial("yggdrasil", destination) + if err != nil { + t.Fatal(err) + } + t.Logf("Opened connection to %q", c.RemoteAddr().String()) + if _, err = c.Write([]byte("Hello!")); err != nil { + t.Fatal(err) + } + if err = c.Close(); err != nil { + t.Fatal(err) + } } }) t.Run("Listen", func(t *testing.T) { t.Parallel() - c, err := quic1.Accept() - if err != nil { - t.Fatal(err) - } - t.Logf("Accepted connection from %q", c.RemoteAddr()) + for i := 0; i < 5; i++ { + c, err := quic1.Accept() + if err != nil { + t.Fatal(err) + } + t.Logf("Accepted connection from %q", c.RemoteAddr()) - b, err := io.ReadAll(c) - if err != nil { - t.Fatal(err) + b, err := io.ReadAll(c) + if err != nil { + t.Fatal(err) + } + t.Logf("Received: %s", b[:]) } - t.Logf("Received: %s", b[:]) }) }