This commit is contained in:
Neil Alexander 2023-12-09 22:01:36 +00:00
parent 3a9cdfd9fd
commit b412fc6f0d
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 47 additions and 45 deletions

View file

@ -24,6 +24,8 @@ import (
)
type YggdrasilTransport struct {
ctx context.Context
cancel context.CancelFunc
yggdrasil net.PacketConn
listener *quic.Listener
transport *quic.Transport
@ -38,11 +40,10 @@ type yggdrasilConnection struct {
context.Context
context.CancelFunc
quic.Connection
done chan struct{}
}
type yggdrasilStream struct {
quic.Connection
*yggdrasilConnection
quic.Stream
}
@ -71,13 +72,14 @@ func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTrans
yggdrasil: ygg,
incoming: make(chan *yggdrasilStream),
}
tr.ctx, tr.cancel = context.WithCancel(context.Background())
var err error
if tr.listener, err = tr.transport.Listen(tr.tlsConfig, tr.quicConfig); err != nil {
return nil, fmt.Errorf("quic.Listen: %w", err)
}
go tr.connectionAcceptLoop(context.TODO())
go tr.connectionAcceptLoop(tr.ctx)
return tr, nil
}
@ -92,41 +94,41 @@ func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) {
// will want to shut down the existing one and replace it with
// this one.
host := qc.RemoteAddr().String()
ctx, cancel := context.WithCancel(ctx)
yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
ctx, cancel := context.WithCancel(t.ctx)
yc := &yggdrasilConnection{ctx, cancel, qc}
if eqc, ok := t.connections.Swap(host, yc); ok {
if eqc, ok := eqc.(*yggdrasilConnection); ok {
eqc.CancelFunc()
}
}
go t.streamAcceptLoop(yc)
// Now if there are any in-progress dials, we can cancel those
// too as we now have an open connection that we can open new
// streams on.
if dial, ok := t.dials.LoadAndDelete(host); ok {
dial := dial.(*yggdrasilDial)
dial.CancelFunc()
dial.(*yggdrasilDial).CancelFunc()
}
go t.streamAcceptLoop(yc)
}
}
func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) {
host := yc.RemoteAddr().String()
defer yc.CloseWithError(0, "Timed out") // nolint:errcheck
defer t.connections.Delete(host)
for {
qs, err := yc.AcceptStream(yc.Context)
if err != nil {
yc.CancelFunc()
return
}
select {
case t.incoming <- &yggdrasilStream{yc.Connection, qs}:
case t.incoming <- &yggdrasilStream{yc, qs}:
// An Accept call is waiting.
case <-yc.Context.Done():
// We've timed out waiting for a call to Accept
// to handle the connection.
return
}
}
@ -140,8 +142,12 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
if network != "yggdrasil" {
return nil, fmt.Errorf("network must be 'yggdrasil'")
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
// Check if there is already a dial to this host in progress.
// If there is then we will wait for it.
if dial, ok := t.dials.Load(host); ok {
<-dial.(*yggdrasilDial).Done()
}
// We might want to retrying once if part of the dial process fails,
// but keep a track of whether we're already retrying.
@ -149,22 +155,14 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
retry:
yc, ok := t.connections.Load(host)
if !ok {
// Check if there is already a dial to this host in progress.
// If there is then we will wait for it.
if dial, ok := t.dials.Load(host); ok {
<-dial.(*yggdrasilDial).Done()
}
// Even after a dial, there's no connection. This means we
// probably failed to dial, so let's try it again.
if yc, ok = t.connections.Load(host); !ok {
// A cancellable context means we can cancel the dial in
// progress from elsewhere if we need to.
dialctx, dialcancel := context.WithCancel(ctx)
defer dialcancel()
// Make a record of the dial context.
dialctx, dialcancel := context.WithTimeout(ctx, time.Second*5)
t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel})
defer dialcancel()
defer t.dials.Delete(host)
// Decode the address from hex.
@ -186,7 +184,7 @@ retry:
// the accept loop so that streams can be accepted.
{
ctx, cancel := context.WithCancel(context.Background())
yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
yc = &yggdrasilConnection{ctx, cancel, qc}
t.connections.Store(host, yc)
go t.streamAcceptLoop(yc.(*yggdrasilConnection))
}
@ -205,7 +203,7 @@ retry:
}
return nil, err
}
return &yggdrasilStream{yc.Connection, qs}, err
return &yggdrasilStream{yc, qs}, err
}
// We failed to open a session.
return nil, net.ErrClosed

View file

@ -52,32 +52,36 @@ func TestQUICOverYggdrasil(t *testing.T) {
t.Parallel()
destination := hex.EncodeToString(node1.PublicKey())
c, err := quic2.Dial("yggdrasil", destination)
if err != nil {
t.Fatal(err)
}
t.Logf("Opened connection to %q", c.RemoteAddr().String())
if _, err = c.Write([]byte("Hello!")); err != nil {
t.Fatal(err)
}
if err = c.Close(); err != nil {
t.Fatal(err)
for i := 0; i < 5; i++ {
c, err := quic2.Dial("yggdrasil", destination)
if err != nil {
t.Fatal(err)
}
t.Logf("Opened connection to %q", c.RemoteAddr().String())
if _, err = c.Write([]byte("Hello!")); err != nil {
t.Fatal(err)
}
if err = c.Close(); err != nil {
t.Fatal(err)
}
}
})
t.Run("Listen", func(t *testing.T) {
t.Parallel()
c, err := quic1.Accept()
if err != nil {
t.Fatal(err)
}
t.Logf("Accepted connection from %q", c.RemoteAddr())
for i := 0; i < 5; i++ {
c, err := quic1.Accept()
if err != nil {
t.Fatal(err)
}
t.Logf("Accepted connection from %q", c.RemoteAddr())
b, err := io.ReadAll(c)
if err != nil {
t.Fatal(err)
b, err := io.ReadAll(c)
if err != nil {
t.Fatal(err)
}
t.Logf("Received: %s", b[:])
}
t.Logf("Received: %s", b[:])
})
}