mirror of
https://github.com/yggdrasil-network/yggquic.git
synced 2025-05-19 16:35:09 +03:00
Tweaks
This commit is contained in:
parent
3a9cdfd9fd
commit
b412fc6f0d
2 changed files with 47 additions and 45 deletions
50
yggquic.go
50
yggquic.go
|
@ -24,6 +24,8 @@ import (
|
|||
)
|
||||
|
||||
type YggdrasilTransport struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
yggdrasil net.PacketConn
|
||||
listener *quic.Listener
|
||||
transport *quic.Transport
|
||||
|
@ -38,11 +40,10 @@ type yggdrasilConnection struct {
|
|||
context.Context
|
||||
context.CancelFunc
|
||||
quic.Connection
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type yggdrasilStream struct {
|
||||
quic.Connection
|
||||
*yggdrasilConnection
|
||||
quic.Stream
|
||||
}
|
||||
|
||||
|
@ -71,13 +72,14 @@ func New(ygg *core.Core, cert tls.Certificate, qc *quic.Config) (*YggdrasilTrans
|
|||
yggdrasil: ygg,
|
||||
incoming: make(chan *yggdrasilStream),
|
||||
}
|
||||
tr.ctx, tr.cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
if tr.listener, err = tr.transport.Listen(tr.tlsConfig, tr.quicConfig); err != nil {
|
||||
return nil, fmt.Errorf("quic.Listen: %w", err)
|
||||
}
|
||||
|
||||
go tr.connectionAcceptLoop(context.TODO())
|
||||
go tr.connectionAcceptLoop(tr.ctx)
|
||||
return tr, nil
|
||||
}
|
||||
|
||||
|
@ -92,41 +94,41 @@ func (t *YggdrasilTransport) connectionAcceptLoop(ctx context.Context) {
|
|||
// will want to shut down the existing one and replace it with
|
||||
// this one.
|
||||
host := qc.RemoteAddr().String()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
yc := &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
|
||||
ctx, cancel := context.WithCancel(t.ctx)
|
||||
yc := &yggdrasilConnection{ctx, cancel, qc}
|
||||
if eqc, ok := t.connections.Swap(host, yc); ok {
|
||||
if eqc, ok := eqc.(*yggdrasilConnection); ok {
|
||||
eqc.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
go t.streamAcceptLoop(yc)
|
||||
|
||||
// Now if there are any in-progress dials, we can cancel those
|
||||
// too as we now have an open connection that we can open new
|
||||
// streams on.
|
||||
if dial, ok := t.dials.LoadAndDelete(host); ok {
|
||||
dial := dial.(*yggdrasilDial)
|
||||
dial.CancelFunc()
|
||||
dial.(*yggdrasilDial).CancelFunc()
|
||||
}
|
||||
|
||||
go t.streamAcceptLoop(yc)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *YggdrasilTransport) streamAcceptLoop(yc *yggdrasilConnection) {
|
||||
host := yc.RemoteAddr().String()
|
||||
|
||||
defer yc.CloseWithError(0, "Timed out") // nolint:errcheck
|
||||
defer t.connections.Delete(host)
|
||||
|
||||
for {
|
||||
qs, err := yc.AcceptStream(yc.Context)
|
||||
if err != nil {
|
||||
yc.CancelFunc()
|
||||
return
|
||||
}
|
||||
select {
|
||||
case t.incoming <- &yggdrasilStream{yc.Connection, qs}:
|
||||
case t.incoming <- &yggdrasilStream{yc, qs}:
|
||||
// An Accept call is waiting.
|
||||
case <-yc.Context.Done():
|
||||
// We've timed out waiting for a call to Accept
|
||||
// to handle the connection.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -140,8 +142,12 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
|
|||
if network != "yggdrasil" {
|
||||
return nil, fmt.Errorf("network must be 'yggdrasil'")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
// Check if there is already a dial to this host in progress.
|
||||
// If there is then we will wait for it.
|
||||
if dial, ok := t.dials.Load(host); ok {
|
||||
<-dial.(*yggdrasilDial).Done()
|
||||
}
|
||||
|
||||
// We might want to retrying once if part of the dial process fails,
|
||||
// but keep a track of whether we're already retrying.
|
||||
|
@ -149,22 +155,14 @@ func (t *YggdrasilTransport) DialContext(ctx context.Context, network, host stri
|
|||
retry:
|
||||
yc, ok := t.connections.Load(host)
|
||||
if !ok {
|
||||
// Check if there is already a dial to this host in progress.
|
||||
// If there is then we will wait for it.
|
||||
if dial, ok := t.dials.Load(host); ok {
|
||||
<-dial.(*yggdrasilDial).Done()
|
||||
}
|
||||
|
||||
// Even after a dial, there's no connection. This means we
|
||||
// probably failed to dial, so let's try it again.
|
||||
if yc, ok = t.connections.Load(host); !ok {
|
||||
// A cancellable context means we can cancel the dial in
|
||||
// progress from elsewhere if we need to.
|
||||
dialctx, dialcancel := context.WithCancel(ctx)
|
||||
defer dialcancel()
|
||||
|
||||
// Make a record of the dial context.
|
||||
dialctx, dialcancel := context.WithTimeout(ctx, time.Second*5)
|
||||
t.dials.Store(host, &yggdrasilDial{dialctx, dialcancel})
|
||||
defer dialcancel()
|
||||
defer t.dials.Delete(host)
|
||||
|
||||
// Decode the address from hex.
|
||||
|
@ -186,7 +184,7 @@ retry:
|
|||
// the accept loop so that streams can be accepted.
|
||||
{
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
yc = &yggdrasilConnection{ctx, cancel, qc, make(chan struct{})}
|
||||
yc = &yggdrasilConnection{ctx, cancel, qc}
|
||||
t.connections.Store(host, yc)
|
||||
go t.streamAcceptLoop(yc.(*yggdrasilConnection))
|
||||
}
|
||||
|
@ -205,7 +203,7 @@ retry:
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
return &yggdrasilStream{yc.Connection, qs}, err
|
||||
return &yggdrasilStream{yc, qs}, err
|
||||
}
|
||||
// We failed to open a session.
|
||||
return nil, net.ErrClosed
|
||||
|
|
|
@ -52,6 +52,7 @@ func TestQUICOverYggdrasil(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
destination := hex.EncodeToString(node1.PublicKey())
|
||||
for i := 0; i < 5; i++ {
|
||||
c, err := quic2.Dial("yggdrasil", destination)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -63,11 +64,13 @@ func TestQUICOverYggdrasil(t *testing.T) {
|
|||
if err = c.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Listen", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
c, err := quic1.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -79,5 +82,6 @@ func TestQUICOverYggdrasil(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Received: %s", b[:])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue