[WIP] Introduce TCP/UDP local/remote port forwarding

Signed-off-by: Vasyl Gello <vasek.gello@gmail.com>
This commit is contained in:
Vasyl Gello 2024-07-18 22:24:46 +03:00
parent 30d51ba566
commit 0783b429fd
4 changed files with 514 additions and 82 deletions

View file

@ -7,62 +7,269 @@ import (
"strings"
)
func parseMappingString(value string) (first_address string, first_port int, second_address string, second_port int, err error) {
var first_port_string string = ""
var second_port_string string = ""
tokens := strings.Split(value, ":")
tokens_len := len(tokens)
// If token count is 1, then it is first and second port the same
if tokens_len == 1 {
first_port, err = strconv.Atoi(tokens[0])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_port = first_port
}
// If token count is 2, then it is <first-port>:<second-port>
if tokens_len == 2 {
first_port, err = strconv.Atoi(tokens[0])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_port, err = strconv.Atoi(tokens[1])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
}
// If token count is 3, parse it as
// <first-port>:<second-address>:<second-port>
if tokens_len == 3 {
first_port, err = strconv.Atoi(tokens[0])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_address, second_port_string, err = net.SplitHostPort(
tokens[1] + ":" + tokens[2])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_port, err = strconv.Atoi(second_port_string)
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
}
// If token count is 4, parse it as
// <first-address>:<first-port>:<second-address>:<second-port>
if tokens_len == 4 {
first_address, first_port_string, err = net.SplitHostPort(
tokens[0] + ":" + tokens[1])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_address, second_port_string, err = net.SplitHostPort(
tokens[0] + ":" + tokens[1])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
first_port, err = strconv.Atoi(first_port_string)
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
second_port, err = strconv.Atoi(second_port_string)
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
}
if tokens_len > 4 {
// Last token needs to be the second_port
second_port, err = strconv.Atoi(tokens[tokens_len-1])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
// Cut seen tokens
tokens = tokens[:tokens_len-1]
tokens_len = len(tokens)
if strings.HasSuffix(tokens[tokens_len-1], "]") {
// Reverse-walk over tokens to find the end of
// numeric ipv6 address
for i := tokens_len - 1; i >= 0; i-- {
if strings.HasPrefix(tokens[i], "[") {
// Store second address
second_address = strings.Join(tokens[i:], ":")
second_address, _ = strings.CutPrefix(second_address, "[")
second_address, _ = strings.CutSuffix(second_address, "]")
// Cut seen tokens
tokens = tokens[:i]
// break from loop
break
}
}
} else {
// next is second address in non-numerical-ipv6 form
second_address = tokens[tokens_len-1]
tokens = tokens[:tokens_len-1]
}
tokens_len = len(tokens)
if tokens_len < 1 {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
// Last token needs to be the first_port
first_port, err = strconv.Atoi(tokens[tokens_len-1])
if err != nil {
return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value)
}
// Cut seen tokens
tokens = tokens[:tokens_len-1]
tokens_len = len(tokens)
if tokens_len > 0 {
if strings.HasSuffix(tokens[tokens_len-1], "]") {
// Reverse-walk over tokens to find the end of
// numeric ipv6 address
for i := tokens_len - 1; i >= 0; i-- {
if strings.HasPrefix(tokens[i], "[") {
// Store first address
first_address = strings.Join(tokens[i:], ":")
first_address, _ = strings.CutPrefix(first_address, "[")
first_address, _ = strings.CutSuffix(first_address, "]")
// break from loop
break
}
}
} else {
// next is first address in non-numerical-ipv6 form
first_address = tokens[tokens_len-1]
}
}
}
if first_port == 0 || second_port == 0 {
return "", 0, "", 0, fmt.Errorf("Ports must not be zero")
}
return first_address, first_port, second_address, second_port, nil
}
type TCPMapping struct {
Listen *net.TCPAddr
Mapped *net.TCPAddr
}
type TCPMappings []TCPMapping
type TCPLocalMappings []TCPMapping
func (m *TCPMappings) String() string {
func (m *TCPLocalMappings) String() string {
return ""
}
func (m *TCPMappings) Set(value string) error {
tokens := strings.Split(value, ":")
if len(tokens) > 2 {
tokens = strings.SplitN(value, ":", 2)
host, port, err := net.SplitHostPort(tokens[1])
if err != nil {
return fmt.Errorf("failed to split host and port: %w", err)
}
tokens = append(tokens[:1], host, port)
}
listenport, err := strconv.Atoi(tokens[0])
func (m *TCPLocalMappings) Set(value string) error {
first_address, first_port, second_address, second_port, err :=
parseMappingString(value)
if err != nil {
return fmt.Errorf("listen port is invalid: %w", err)
return err
}
if listenport == 0 {
return fmt.Errorf("listen port must not be zero")
// First address can be ipv4/ipv6
// Second address can be only Yggdrasil ipv6
if !strings.Contains(second_address, ":") {
return fmt.Errorf("Yggdrasil listening address can be only IPv6")
}
// Create mapping
mapping := TCPMapping{
Listen: &net.TCPAddr{
Port: listenport,
Port: first_port,
},
Mapped: &net.TCPAddr{
IP: net.IPv6loopback,
Port: listenport,
Port: second_port,
},
}
tokens = tokens[1:]
if len(tokens) > 0 {
mappedaddr := net.ParseIP(tokens[0])
if first_address != "" {
listenaddr := net.ParseIP(first_address)
if listenaddr == nil {
return fmt.Errorf("invalid listen address %q", first_address)
}
mapping.Listen.IP = listenaddr
}
if second_address != "" {
mappedaddr := net.ParseIP(second_address)
if mappedaddr == nil {
return fmt.Errorf("invalid mapped address %q", tokens[0])
return fmt.Errorf("invalid mapped address %q", second_address)
}
// TODO: Filter Yggdrasil IPs here
mapping.Mapped.IP = mappedaddr
}
*m = append(*m, mapping)
return nil
}
type TCPRemoteMappings []TCPMapping
func (m *TCPRemoteMappings) String() string {
return ""
}
func (m *TCPRemoteMappings) Set(value string) error {
first_address, first_port, second_address, second_port, err :=
parseMappingString(value)
if err != nil {
return err
}
// First address must be empty
// Second address can be ipv4/ipv6
if first_address != "" {
return fmt.Errorf("Yggdrasil listening must be empty")
}
// Create mapping
mapping := TCPMapping{
Listen: &net.TCPAddr{
Port: first_port,
},
Mapped: &net.TCPAddr{
IP: net.IPv6loopback,
Port: second_port,
},
}
if first_address != "" {
listenaddr := net.ParseIP(first_address)
if listenaddr == nil {
return fmt.Errorf("invalid listen address %q", first_address)
}
mapping.Listen.IP = listenaddr
}
if second_address != "" {
mappedaddr := net.ParseIP(second_address)
if mappedaddr == nil {
return fmt.Errorf("invalid mapped address %q", second_address)
}
mapping.Mapped.IP = mappedaddr
tokens = tokens[1:]
}
if len(tokens) > 0 {
mappedport, err := strconv.Atoi(tokens[0])
if err != nil {
return fmt.Errorf("mapped port is invalid: %w", err)
}
if mappedport == 0 {
return fmt.Errorf("mapped port must not be zero")
}
mapping.Mapped.Port = mappedport
}
*m = append(*m, mapping)
return nil
}
@ -72,57 +279,109 @@ type UDPMapping struct {
Mapped *net.UDPAddr
}
type UDPMappings []UDPMapping
type UDPLocalMappings []UDPMapping
func (m *UDPMappings) String() string {
func (m *UDPLocalMappings) String() string {
return ""
}
func (m *UDPMappings) Set(value string) error {
tokens := strings.Split(value, ":")
if len(tokens) > 2 {
tokens = strings.SplitN(value, ":", 2)
host, port, err := net.SplitHostPort(tokens[1])
if err != nil {
return fmt.Errorf("failed to split host and port: %w", err)
}
tokens = append(tokens[:1], host, port)
}
listenport, err := strconv.Atoi(tokens[0])
func (m *UDPLocalMappings) Set(value string) error {
first_address, first_port, second_address, second_port, err :=
parseMappingString(value)
if err != nil {
return fmt.Errorf("listen port is invalid: %w", err)
return err
}
if listenport == 0 {
return fmt.Errorf("listen port must not be zero")
// First address can be ipv4/ipv6
// Second address can be only Yggdrasil ipv6
if !strings.Contains(second_address, ":") {
return fmt.Errorf("Yggdrasil listening address can be only IPv6")
}
// Create mapping
mapping := UDPMapping{
Listen: &net.UDPAddr{
Port: listenport,
Port: first_port,
},
Mapped: &net.UDPAddr{
IP: net.IPv6loopback,
Port: listenport,
Port: second_port,
},
}
tokens = tokens[1:]
if len(tokens) > 0 {
mappedaddr := net.ParseIP(tokens[0])
if first_address != "" {
listenaddr := net.ParseIP(first_address)
if listenaddr == nil {
return fmt.Errorf("invalid listen address %q", first_address)
}
mapping.Listen.IP = listenaddr
}
if second_address != "" {
mappedaddr := net.ParseIP(second_address)
if mappedaddr == nil {
return fmt.Errorf("invalid mapped address %q", tokens[0])
return fmt.Errorf("invalid mapped address %q", second_address)
}
// TODO: Filter Yggdrasil IPs here
mapping.Mapped.IP = mappedaddr
tokens = tokens[1:]
}
if len(tokens) > 0 {
mappedport, err := strconv.Atoi(tokens[0])
if err != nil {
return fmt.Errorf("mapped port is invalid: %w", err)
}
if mappedport == 0 {
return fmt.Errorf("mapped port must not be zero")
}
mapping.Mapped.Port = mappedport
}
*m = append(*m, mapping)
return nil
}
type UDPRemoteMappings []UDPMapping
func (m *UDPRemoteMappings) String() string {
return ""
}
func (m *UDPRemoteMappings) Set(value string) error {
first_address, first_port, second_address, second_port, err :=
parseMappingString(value)
if err != nil {
return err
}
// First address must be empty
// Second address can be ipv4/ipv6
if first_address != "" {
return fmt.Errorf("Yggdrasil listening must be empty")
}
// Create mapping
mapping := UDPMapping{
Listen: &net.UDPAddr{
Port: first_port,
},
Mapped: &net.UDPAddr{
IP: net.IPv6loopback,
Port: second_port,
},
}
if first_address != "" {
listenaddr := net.ParseIP(first_address)
if listenaddr == nil {
return fmt.Errorf("invalid listen address %q", first_address)
}
mapping.Listen.IP = listenaddr
}
if second_address != "" {
mappedaddr := net.ParseIP(second_address)
if mappedaddr == nil {
return fmt.Errorf("invalid mapped address %q", second_address)
}
mapping.Mapped.IP = mappedaddr
}
*m = append(*m, mapping)
return nil
}

View file

@ -13,18 +13,48 @@ func TestEndpointMappings(t *testing.T) {
if err := tcpMappings.Set("1234:192.168.1.1:4321"); err != nil {
t.Fatal(err)
}
if err := tcpMappings.Set("192.168.1.2:1234:192.168.1.1:4321"); err != nil {
t.Fatal(err)
}
if err := tcpMappings.Set("1234:[2000::1]:4321"); err != nil {
t.Fatal(err)
}
if err := tcpMappings.Set("[2001:1]:1234:[2000::1]:4321"); err != nil {
t.Fatal(err)
}
if err := tcpMappings.Set("a"); err == nil {
t.Fatal("'a' should be an invalid exposed port")
}
if err := tcpMappings.Set("1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := tcpMappings.Set("127.0.0.1:1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := tcpMappings.Set("[2000:1]:1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := tcpMappings.Set("localhost:1234:127.0.0.1"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := tcpMappings.Set("localhost:1234:127.0.0.1"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := tcpMappings.Set("localhost:1234:[2000:1]"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := tcpMappings.Set("localhost:1234:[2000:1]"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := tcpMappings.Set("1234:localhost:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
if err := tcpMappings.Set("127.0.0.1:1234:127.0.0.1:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
if err := tcpMappings.Set("[2000::1]:1234:[2000::1]:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
var udpMappings UDPMappings
if err := udpMappings.Set("1234"); err != nil {
t.Fatal(err)
@ -35,16 +65,46 @@ func TestEndpointMappings(t *testing.T) {
if err := udpMappings.Set("1234:192.168.1.1:4321"); err != nil {
t.Fatal(err)
}
if err := udpMappings.Set("192.168.1.2:1234:192.168.1.1:4321"); err != nil {
t.Fatal(err)
}
if err := udpMappings.Set("1234:[2000::1]:4321"); err != nil {
t.Fatal(err)
}
if err := udpMappings.Set("[2001:1]:1234:[2000::1]:4321"); err != nil {
t.Fatal(err)
}
if err := udpMappings.Set("a"); err == nil {
t.Fatal("'a' should be an invalid exposed port")
}
if err := udpMappings.Set("1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := udpMappings.Set("127.0.0.1:1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := udpMappings.Set("[2000:1]:1234:localhost"); err == nil {
t.Fatal("mapped address must be an IP literal")
}
if err := udpMappings.Set("localhost:1234:127.0.0.1"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := udpMappings.Set("localhost:1234:127.0.0.1"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := udpMappings.Set("localhost:1234:[2000:1]"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := udpMappings.Set("localhost:1234:[2000:1]"); err == nil {
t.Fatal("listen address must be an IP literal")
}
if err := udpMappings.Set("1234:localhost:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
if err := udpMappings.Set("127.0.0.1:1234:127.0.0.1:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
if err := udpMappings.Set("[2000::1]:1234:[2000::1]:a"); err == nil {
t.Fatal("'a' should be an invalid mapped port")
}
}

View file

@ -4,7 +4,7 @@ import (
"net"
)
func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error {
func ReverseProxyUDPConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error {
buf := make([]byte, mtu)
for {
n, err := src.Read(buf[:])
@ -20,3 +20,20 @@ func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.U
}
return nil
}
func ReverseProxyUDPPacketConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn) error {
buf := make([]byte, mtu)
for {
n, _, err := src.ReadFrom(buf[:])
if err != nil {
return err
}
if n > 0 {
n, err = dst.WriteTo(buf[:n], dstAddr)
if err != nil {
return err
}
}
}
return nil
}