Try to fix #4

... by catching TCP RST packets in WritePackets and sending them
during the next WritePackets call where no RST packet is being sent

Signed-off-by: Vasyl Gello <vasek.gello@gmail.com>
This commit is contained in:
Vasyl Gello 2024-07-18 11:48:42 +03:00
parent b160b3f66d
commit 30d51ba566

View file

@ -12,6 +12,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
) )
type YggdrasilNIC struct { type YggdrasilNIC struct {
@ -20,6 +21,7 @@ type YggdrasilNIC struct {
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
readBuf []byte readBuf []byte
writeBuf []byte writeBuf []byte
rstPackets chan *stack.PacketBuffer
} }
func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error { func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error {
@ -29,6 +31,7 @@ func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error {
ipv6rwc: rwc, ipv6rwc: rwc,
readBuf: make([]byte, mtu), readBuf: make([]byte, mtu),
writeBuf: make([]byte, mtu), writeBuf: make([]byte, mtu),
rstPackets: make(chan *stack.PacketBuffer, 100),
} }
if err := s.stack.CreateNIC(1, nic); err != nil { if err := s.stack.CreateNIC(1, nic); err != nil {
return err return err
@ -48,6 +51,15 @@ func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error {
nic.dispatcher.DeliverNetworkPacket(ipv6.ProtocolNumber, pkb) nic.dispatcher.DeliverNetworkPacket(ipv6.ProtocolNumber, pkb)
} }
}() }()
go func() {
for {
pkt := <- nic.rstPackets
if pkt == nil {
continue
}
_ = nic.writePacket(pkt)
}
}()
_, snet, err := net.ParseCIDR("0200::/7") _, snet, err := net.ParseCIDR("0200::/7")
if err != nil { if err != nil {
return &tcpip.ErrBadAddress{} return &tcpip.ErrBadAddress{}
@ -93,21 +105,48 @@ func (*YggdrasilNIC) LinkAddress() tcpip.LinkAddress { return "" }
func (*YggdrasilNIC) Wait() {} func (*YggdrasilNIC) Wait() {}
func (e *YggdrasilNIC) writePacket(
pkt *stack.PacketBuffer,
) tcpip.Error {
// We need to recover from panic() here because
// parser in ToView() gets confused on some packets
// without payload and panics
defer func() {
r := recover()
if r != nil {
}
}()
vv := pkt.ToView()
n, err := vv.Read(e.writeBuf)
if err != nil {
return &tcpip.ErrAborted{}
}
_, err = e.ipv6rwc.Write(e.writeBuf[:n])
if err != nil {
return &tcpip.ErrAborted{}
}
return nil
}
func (e *YggdrasilNIC) WritePackets( func (e *YggdrasilNIC) WritePackets(
list stack.PacketBufferList, list stack.PacketBufferList,
) (int, tcpip.Error) { ) (int, tcpip.Error) {
var i int = 0 var i int = 0
var err tcpip.Error = nil
for i, pkt := range list.AsSlice() { for i, pkt := range list.AsSlice() {
vv := pkt.ToView() if pkt.Data().Size() == 0 {
n, err := vv.Read(e.writeBuf) if pkt.Network().TransportProtocol() == tcp.ProtocolNumber {
if err != nil { tcpHeader := header.TCP(pkt.TransportHeader().Slice())
log.Println(err) if (tcpHeader.Flags() & header.TCPFlagRst) == header.TCPFlagRst {
return i - 1, &tcpip.ErrAborted{} e.rstPackets <- pkt
continue
} }
_, err = e.ipv6rwc.Write(e.writeBuf[:n]) }
}
err = e.writePacket(pkt)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return i - 1, &tcpip.ErrAborted{} return i - 1, err
} }
} }