This commit is contained in:
Neil Alexander 2023-12-09 22:01:36 +00:00
parent 3a9cdfd9fd
commit b412fc6f0d
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 47 additions and 45 deletions

View file

@ -24,6 +24,8 @@ import (
) )
type YggdrasilTransport struct { type YggdrasilTransport struct {
ctx context.Context
cancel context.CancelFunc
yggdrasil net.PacketConn yggdrasil net.PacketConn
listener *quic.Listener listener *quic.Listener
transport *quic.Transport transport *quic.Transport
@ -38,11 +40,10 @@ type yggdrasilConnection struct {
context.Context context.Context
context.CancelFunc context.CancelFunc
quic.Connection quic.Connection
done chan struct{}
} }
type yggdrasilStream struct { type yggdrasilStream struct {
quic.Connection *yggdrasilConnection
quic.Stream quic.Stream
} }
@ -71,13 +72,14 @@ func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTrans
yggdrasil: ygg, yggdrasil: ygg,
incoming: make(chan *yggdrasilStream), incoming: make(chan *yggdrasilStream),
} }
tr.ctx, tr.cancel = context.WithCancel(context.Background())
var err error var err error
if tr.listener, err = tr.transport.Listen(tr.tlsConfig, tr.quicConfig); err != nil { if tr.listener, err = tr.transport.Listen(tr.tlsConfig, tr.quicConfig); err != nil {
return nil, fmt.Errorf("quic.Listen: %w", err) return nil, fmt.Errorf("quic.Listen: %w", err)
} }
go tr.connectionAcceptLoop(context.TODO()) go tr.connectionAcceptLoop(tr.ctx)
return tr, nil return tr, nil
} }
@ -92,41 +94,41 @@ func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) {
// will want to shut down the existing one and replace it with // will want to shut down the existing one and replace it with
// this one. // this one.
host := qc.RemoteAddr().String() host := qc.RemoteAddr().String()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(t.ctx)
yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} yc := &yggdrasilConnection{ctx, cancel, qc}
if eqc, ok := t.connections.Swap(host, yc); ok { if eqc, ok := t.connections.Swap(host, yc); ok {
if eqc, ok := eqc.(*yggdrasilConnection); ok { if eqc, ok := eqc.(*yggdrasilConnection); ok {
eqc.CancelFunc() eqc.CancelFunc()
} }
} }
go t.streamAcceptLoop(yc)
// Now if there are any in-progress dials, we can cancel those // Now if there are any in-progress dials, we can cancel those
// too as we now have an open connection that we can open new // too as we now have an open connection that we can open new
// streams on. // streams on.
if dial, ok := t.dials.LoadAndDelete(host); ok { if dial, ok := t.dials.LoadAndDelete(host); ok {
dial := dial.(*yggdrasilDial) dial.(*yggdrasilDial).CancelFunc()
dial.CancelFunc()
} }
go t.streamAcceptLoop(yc)
} }
} }
func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) { func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) {
host := yc.RemoteAddr().String() host := yc.RemoteAddr().String()
defer yc.CloseWithError(0, "Timed out") // nolint:errcheck defer yc.CloseWithError(0, "Timed out") // nolint:errcheck
defer t.connections.Delete(host) defer t.connections.Delete(host)
for { for {
qs, err := yc.AcceptStream(yc.Context) qs, err := yc.AcceptStream(yc.Context)
if err != nil { if err != nil {
yc.CancelFunc()
return return
} }
select { select {
case t.incoming <- &yggdrasilStream{yc.Connection, qs}: case t.incoming <- &yggdrasilStream{yc, qs}:
// An Accept call is waiting.
case <-yc.Context.Done(): case <-yc.Context.Done():
// We've timed out waiting for a call to Accept
// to handle the connection.
return return
} }
} }
@ -140,8 +142,12 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
if network != "yggdrasil" { if network != "yggdrasil" {
return nil, fmt.Errorf("network must be 'yggdrasil'") return nil, fmt.Errorf("network must be 'yggdrasil'")
} }
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel() // Check if there is already a dial to this host in progress.
// If there is then we will wait for it.
if dial, ok := t.dials.Load(host); ok {
<-dial.(*yggdrasilDial).Done()
}
// We might want to retrying once if part of the dial process fails, // We might want to retrying once if part of the dial process fails,
// but keep a track of whether we're already retrying. // but keep a track of whether we're already retrying.
@ -149,22 +155,14 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
retry: retry:
yc, ok := t.connections.Load(host) yc, ok := t.connections.Load(host)
if !ok { if !ok {
// Check if there is already a dial to this host in progress.
// If there is then we will wait for it.
if dial, ok := t.dials.Load(host); ok {
<-dial.(*yggdrasilDial).Done()
}
// Even after a dial, there's no connection. This means we // Even after a dial, there's no connection. This means we
// probably failed to dial, so let's try it again. // probably failed to dial, so let's try it again.
if yc, ok = t.connections.Load(host); !ok { if yc, ok = t.connections.Load(host); !ok {
// A cancellable context means we can cancel the dial in // A cancellable context means we can cancel the dial in
// progress from elsewhere if we need to. // progress from elsewhere if we need to.
dialctx, dialcancel := context.WithCancel(ctx) dialctx, dialcancel := context.WithTimeout(ctx, time.Second*5)
defer dialcancel()
// Make a record of the dial context.
t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel}) t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel})
defer dialcancel()
defer t.dials.Delete(host) defer t.dials.Delete(host)
// Decode the address from hex. // Decode the address from hex.
@ -186,7 +184,7 @@ retry:
// the accept loop so that streams can be accepted. // the accept loop so that streams can be accepted.
{ {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})} yc = &yggdrasilConnection{ctx, cancel, qc}
t.connections.Store(host, yc) t.connections.Store(host, yc)
go t.streamAcceptLoop(yc.(*yggdrasilConnection)) go t.streamAcceptLoop(yc.(*yggdrasilConnection))
} }
@ -205,7 +203,7 @@ retry:
} }
return nil, err return nil, err
} }
return &yggdrasilStream{yc.Connection, qs}, err return &yggdrasilStream{yc, qs}, err
} }
// We failed to open a session. // We failed to open a session.
return nil, net.ErrClosed return nil, net.ErrClosed

View file

@ -52,6 +52,7 @@ func TestQUICOverYggdrasil(t *testing.T) {
t.Parallel() t.Parallel()
destination := hex.EncodeToString(node1.PublicKey()) destination := hex.EncodeToString(node1.PublicKey())
for i := 0; i < 5; i++ {
c, err := quic2.Dial("yggdrasil", destination) c, err := quic2.Dial("yggdrasil", destination)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -63,11 +64,13 @@ func TestQUICOverYggdrasil(t *testing.T) {
if err = c.Close(); err != nil { if err = c.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}
}) })
t.Run("Listen", func(t *testing.T) { t.Run("Listen", func(t *testing.T) {
t.Parallel() t.Parallel()
for i := 0; i < 5; i++ {
c, err := quic1.Accept() c, err := quic1.Accept()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -79,5 +82,6 @@ func TestQUICOverYggdrasil(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("Received: %s", b[:]) t.Logf("Received: %s", b[:])
}
}) })
} }