diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index 31598f6..7d6ca98 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -14,6 +14,7 @@ import ( "regexp" "runtime" "strings" + "sync" "syscall" "github.com/gologme/log" @@ -39,9 +40,15 @@ type node struct { socks5Listener net.Listener } +type UDPSession struct { + conn *net.UDPConn + remoteAddr net.Addr +} + // The main function is responsible for configuring and starting Yggdrasil. func main() { - var expose types.TCPMappings + var exposetcp types.TCPMappings + var exposeudp types.UDPMappings genconf := flag.Bool("genconf", false, "print a new config to stdout") useconf := flag.Bool("useconf", false, "read HJSON/JSON config from stdin") useconffile := flag.String("useconffile", "", "read HJSON/JSON config from specified file path") @@ -57,7 +64,8 @@ func main() { loglevel := flag.String("loglevel", "info", "loglevel to enable") socks := flag.String("socks", "", "address to listen on for SOCKS, i.e. :1080; or UNIX socket file path, i.e. /tmp/yggstack.sock") nameserver := flag.String("nameserver", "", "the Yggdrasil IPv6 address to use as a DNS server for SOCKS") - flag.Var(&expose, "exposetcp", "TCP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") + flag.Var(&exposetcp, "exposetcp", "TCP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") + flag.Var(&exposeudp, "exposeudp", "UDP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") flag.Parse() // Catch interrupts from the operating system to exit gracefully. @@ -322,13 +330,13 @@ func main() { // Create TCP mappings { - for _, mapping := range expose { + for _, mapping := range exposetcp { go func(mapping types.TCPMapping) { listener, err := s.ListenTCP(mapping.Listen) if err != nil { panic(err) } - logger.Infof("Mapping Yggdrasil port %d to %s", mapping.Listen.Port, mapping.Mapped) + logger.Infof("Mapping Yggdrasil TCP port %d to %s", mapping.Listen.Port, mapping.Mapped) for { c, err := listener.Accept() if err != nil { @@ -346,6 +354,63 @@ func main() { } } + // Create UDP mappings + { + for _, mapping := range exposeudp { + go func(mapping types.UDPMapping) { + mtu := n.core.MTU() + udpListenConn, err := s.ListenUDP(mapping.Listen) + if err != nil { + panic(err) + } + logger.Infof("Mapping Yggdrasil UDP port %d to %s", mapping.Listen.Port, mapping.Mapped) + remoteUdpConnections := new(sync.Map) + udpBuffer := make([]byte, mtu) + for { + bytesRead, remoteUdpAddr, err := udpListenConn.ReadFrom(udpBuffer) + if err != nil { + if bytesRead == 0 { + continue + } + } + + remoteUdpAddrStr := remoteUdpAddr.String() + + connVal, ok := remoteUdpConnections.Load(remoteUdpAddrStr) + + if !ok { + logger.Infof("Creating new session for %s", remoteUdpAddr.String()) + udpFwdConn, err := net.DialUDP("udp", nil, mapping.Mapped) + if err != nil { + logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) + continue + } + udpSession := &UDPSession{ + conn: udpFwdConn, + remoteAddr: remoteUdpAddr, + } + remoteUdpConnections.Store(remoteUdpAddrStr, udpSession) + go types.ReverseProxyUDP(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) + } + + + udpSession, ok := connVal.(*UDPSession) + if !ok { + continue + } + + _, err = udpSession.conn.Write(udpBuffer[:bytesRead]) + if err != nil { + logger.Debugf("Cannot write from yggdrasil to udp listener: %q", err) + udpSession.conn.Close() + remoteUdpConnections.Delete(remoteUdpAddrStr) + continue + } + } + }(mapping) + } + } + // Block until we are told to shut down. <-ctx.Done() diff --git a/src/types/mapping.go b/src/types/mapping.go index 0d2dd3d..03e99c3 100644 --- a/src/types/mapping.go +++ b/src/types/mapping.go @@ -66,3 +66,63 @@ func (m *TCPMappings) Set(value string) error { *m = append(*m, mapping) return nil } + +type UDPMapping struct { + Listen *net.UDPAddr + Mapped *net.UDPAddr +} + +type UDPMappings []UDPMapping + +func (m *UDPMappings) String() string { + return "" +} + +func (m *UDPMappings) Set(value string) error { + tokens := strings.Split(value, ":") + if len(tokens) > 2 { + tokens = strings.SplitN(value, ":", 2) + host, port, err := net.SplitHostPort(tokens[1]) + if err != nil { + return fmt.Errorf("failed to split host and port: %w", err) + } + tokens = append(tokens[:1], host, port) + } + listenport, err := strconv.Atoi(tokens[0]) + if err != nil { + return fmt.Errorf("listen port is invalid: %w", err) + } + if listenport == 0 { + return fmt.Errorf("listen port must not be zero") + } + mapping := UDPMapping{ + Listen: &net.UDPAddr{ + Port: listenport, + }, + Mapped: &net.UDPAddr{ + IP: net.IPv6loopback, + Port: listenport, + }, + } + tokens = tokens[1:] + if len(tokens) > 0 { + mappedaddr := net.ParseIP(tokens[0]) + if mappedaddr == nil { + return fmt.Errorf("invalid mapped address %q", tokens[0]) + } + mapping.Mapped.IP = mappedaddr + tokens = tokens[1:] + } + if len(tokens) > 0 { + mappedport, err := strconv.Atoi(tokens[0]) + if err != nil { + return fmt.Errorf("mapped port is invalid: %w", err) + } + if mappedport == 0 { + return fmt.Errorf("mapped port must not be zero") + } + mapping.Mapped.Port = mappedport + } + *m = append(*m, mapping) + return nil +} diff --git a/src/types/mapping_test.go b/src/types/mapping_test.go index b96c0e6..23e68b6 100644 --- a/src/types/mapping_test.go +++ b/src/types/mapping_test.go @@ -3,26 +3,48 @@ package types import "testing" func TestEndpointMappings(t *testing.T) { - var mappings TCPMappings - if err := mappings.Set("1234"); err != nil { + var tcpMappings TCPMappings + if err := tcpMappings.Set("1234"); err != nil { t.Fatal(err) } - if err := mappings.Set("1234:192.168.1.1"); err != nil { + if err := tcpMappings.Set("1234:192.168.1.1"); err != nil { t.Fatal(err) } - if err := mappings.Set("1234:192.168.1.1:4321"); err != nil { + if err := tcpMappings.Set("1234:192.168.1.1:4321"); err != nil { t.Fatal(err) } - if err := mappings.Set("1234:[2000::1]:4321"); err != nil { + if err := tcpMappings.Set("1234:[2000::1]:4321"); err != nil { t.Fatal(err) } - if err := mappings.Set("a"); err == nil { + if err := tcpMappings.Set("a"); err == nil { t.Fatal("'a' should be an invalid exposed port") } - if err := mappings.Set("1234:localhost"); err == nil { + if err := tcpMappings.Set("1234:localhost"); err == nil { t.Fatal("mapped address must be an IP literal") } - if err := mappings.Set("1234:localhost:a"); err == nil { + if err := tcpMappings.Set("1234:localhost:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } + var udpMappings UDPMappings + if err := udpMappings.Set("1234"); err != nil { + t.Fatal(err) + } + if err := udpMappings.Set("1234:192.168.1.1"); err != nil { + t.Fatal(err) + } + if err := udpMappings.Set("1234:192.168.1.1:4321"); err != nil { + t.Fatal(err) + } + if err := udpMappings.Set("1234:[2000::1]:4321"); err != nil { + t.Fatal(err) + } + if err := udpMappings.Set("a"); err == nil { + t.Fatal("'a' should be an invalid exposed port") + } + if err := udpMappings.Set("1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := udpMappings.Set("1234:localhost:a"); err == nil { t.Fatal("'a' should be an invalid mapped port") } } diff --git a/src/types/udpproxy.go b/src/types/udpproxy.go new file mode 100644 index 0000000..7fc4375 --- /dev/null +++ b/src/types/udpproxy.go @@ -0,0 +1,22 @@ +package types + +import ( + "net" +) + +func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error { + buf := make([]byte, mtu) + for { + n, err := src.Read(buf[:]) + if err != nil { + return err + } + if n > 0 { + n, err = dst.WriteTo(buf[:n], dstAddr) + if err != nil { + return err + } + } + } + return nil +}