Try to fix #4

... by catching TCP RST packets in WritePackets and sending them
during the next WritePackets call where no RST packet is being sent

Signed-off-by: Vasyl Gello <vasek.gello@gmail.com>
This commit is contained in:
Vasyl Gello 2024-07-18 11:48:42 +03:00
parent b160b3f66d
commit 30d51ba566

View file

@ -12,6 +12,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
type YggdrasilNIC struct {
@ -20,15 +21,17 @@ type YggdrasilNIC struct {
dispatcher stack.NetworkDispatcher
readBuf []byte
writeBuf []byte
rstPackets chan *stack.PacketBuffer
}
func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error {
rwc := ipv6rwc.NewReadWriteCloser(ygg)
mtu := rwc.MTU()
nic := &YggdrasilNIC{
ipv6rwc: rwc,
readBuf: make([]byte, mtu),
writeBuf: make([]byte, mtu),
ipv6rwc: rwc,
readBuf: make([]byte, mtu),
writeBuf: make([]byte, mtu),
rstPackets: make(chan *stack.PacketBuffer, 100),
}
if err := s.stack.CreateNIC(1, nic); err != nil {
return err
@ -48,6 +51,15 @@ func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error {
nic.dispatcher.DeliverNetworkPacket(ipv6.ProtocolNumber, pkb)
}
}()
go func() {
for {
pkt := <- nic.rstPackets
if pkt == nil {
continue
}
_ = nic.writePacket(pkt)
}
}()
_, snet, err := net.ParseCIDR("0200::/7")
if err != nil {
return &tcpip.ErrBadAddress{}
@ -93,21 +105,48 @@ func (*YggdrasilNIC) LinkAddress() tcpip.LinkAddress { return "" }
func (*YggdrasilNIC) Wait() {}
func (e *YggdrasilNIC) writePacket(
pkt *stack.PacketBuffer,
) tcpip.Error {
// We need to recover from panic() here because
// parser in ToView() gets confused on some packets
// without payload and panics
defer func() {
r := recover()
if r != nil {
}
}()
vv := pkt.ToView()
n, err := vv.Read(e.writeBuf)
if err != nil {
return &tcpip.ErrAborted{}
}
_, err = e.ipv6rwc.Write(e.writeBuf[:n])
if err != nil {
return &tcpip.ErrAborted{}
}
return nil
}
func (e *YggdrasilNIC) WritePackets(
list stack.PacketBufferList,
) (int, tcpip.Error) {
var i int = 0
var err tcpip.Error = nil
for i, pkt := range list.AsSlice() {
vv := pkt.ToView()
n, err := vv.Read(e.writeBuf)
if err != nil {
log.Println(err)
return i - 1, &tcpip.ErrAborted{}
if pkt.Data().Size() == 0 {
if pkt.Network().TransportProtocol() == tcp.ProtocolNumber {
tcpHeader := header.TCP(pkt.TransportHeader().Slice())
if (tcpHeader.Flags() & header.TCPFlagRst) == header.TCPFlagRst {
e.rstPackets <- pkt
continue
}
}
}
_, err = e.ipv6rwc.Write(e.writeBuf[:n])
err = e.writePacket(pkt)
if err != nil {
log.Println(err)
return i - 1, &tcpip.ErrAborted{}
return i - 1, err
}
}