diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index d4d6c89d..0857af68 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -1,13 +1,11 @@ package main import ( - "context" "crypto/ed25519" "encoding/hex" "encoding/json" "flag" "fmt" - "io" "net" "os" "os/signal" @@ -18,6 +16,7 @@ import ( "github.com/hjson/hjson-go" "github.com/things-go/go-socks5" + "github.com/yggdrasil-network/yggdrasil-go/cmd/yggstack/types" "github.com/yggdrasil-network/yggdrasil-go/contrib/netstack" "github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/config" @@ -25,23 +24,15 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/version" - "net/http" _ "net/http/pprof" ) -type nameResolver struct{} - -func (r *nameResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { - ip := net.ParseIP(name) - if ip == nil { - return nil, nil, fmt.Errorf("not a valid IP address") - } - return ctx, ip, nil -} - // The main function is responsible for configuring and starting Yggdrasil. func main() { + var expose types.TCPMappings socks := flag.String("socks", "", "address to listen on for SOCKS, i.e. :1080") + nameserver := flag.String("nameserver", "", "the Yggdrasil IPv6 address to use as a DNS server for SOCKS") + flag.Var(&expose, "exposetcp", "TCP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") args := setup.ParseArguments() // Create a new logger that logs output to stdout. @@ -162,7 +153,7 @@ func main() { } if *socks != "" { - resolver := &nameResolver{} + resolver := types.NewNameResolver(s, *nameserver) server := socks5.NewServer( socks5.WithDial(s.DialContext), socks5.WithResolver(resolver), @@ -170,15 +161,28 @@ func main() { go server.ListenAndServe("tcp", *socks) // nolint:errcheck } - listener, err := s.ListenTCP(&net.TCPAddr{Port: 80}) - if err != nil { - log.Panicln(err) + for _, mapping := range expose { + go func(mapping types.TCPMapping) { + listener, err := s.ListenTCP(mapping.Listen) + if err != nil { + panic(err) + } + logger.Infof("Mapping Yggdrasil port %d to %s", mapping.Listen.Port, mapping.Mapped) + for { + c, err := listener.Accept() + if err != nil { + panic(err) + } + r, err := net.DialTCP("tcp", nil, mapping.Mapped) + if err != nil { + logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) + _ = c.Close() + continue + } + types.ProxyTCP(n.MTU(), c, r) + } + }(mapping) } - http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - _, _ = io.WriteString(writer, "I am Yggstack!") - }) - httpServer := &http.Server{} - go httpServer.Serve(listener) // nolint:errcheck term := make(chan os.Signal, 1) signal.Notify(term, os.Interrupt, syscall.SIGTERM) diff --git a/cmd/yggstack/types/mapping.go b/cmd/yggstack/types/mapping.go new file mode 100644 index 00000000..0d2dd3d5 --- /dev/null +++ b/cmd/yggstack/types/mapping.go @@ -0,0 +1,68 @@ +package types + +import ( + "fmt" + "net" + "strconv" + "strings" +) + +type TCPMapping struct { + Listen *net.TCPAddr + Mapped *net.TCPAddr +} + +type TCPMappings []TCPMapping + +func (m *TCPMappings) String() string { + return "" +} + +func (m *TCPMappings) Set(value string) error { + tokens := strings.Split(value, ":") + if len(tokens) > 2 { + tokens = strings.SplitN(value, ":", 2) + host, port, err := net.SplitHostPort(tokens[1]) + if err != nil { + return fmt.Errorf("failed to split host and port: %w", err) + } + tokens = append(tokens[:1], host, port) + } + listenport, err := strconv.Atoi(tokens[0]) + if err != nil { + return fmt.Errorf("listen port is invalid: %w", err) + } + if listenport == 0 { + return fmt.Errorf("listen port must not be zero") + } + mapping := TCPMapping{ + Listen: &net.TCPAddr{ + Port: listenport, + }, + Mapped: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: listenport, + }, + } + tokens = tokens[1:] + if len(tokens) > 0 { + mappedaddr := net.ParseIP(tokens[0]) + if mappedaddr == nil { + return fmt.Errorf("invalid mapped address %q", tokens[0]) + } + mapping.Mapped.IP = mappedaddr + tokens = tokens[1:] + } + if len(tokens) > 0 { + mappedport, err := strconv.Atoi(tokens[0]) + if err != nil { + return fmt.Errorf("mapped port is invalid: %w", err) + } + if mappedport == 0 { + return fmt.Errorf("mapped port must not be zero") + } + mapping.Mapped.Port = mappedport + } + *m = append(*m, mapping) + return nil +} diff --git a/cmd/yggstack/types/mapping_test.go b/cmd/yggstack/types/mapping_test.go new file mode 100644 index 00000000..b96c0e68 --- /dev/null +++ b/cmd/yggstack/types/mapping_test.go @@ -0,0 +1,28 @@ +package types + +import "testing" + +func TestEndpointMappings(t *testing.T) { + var mappings TCPMappings + if err := mappings.Set("1234"); err != nil { + t.Fatal(err) + } + if err := mappings.Set("1234:192.168.1.1"); err != nil { + t.Fatal(err) + } + if err := mappings.Set("1234:192.168.1.1:4321"); err != nil { + t.Fatal(err) + } + if err := mappings.Set("1234:[2000::1]:4321"); err != nil { + t.Fatal(err) + } + if err := mappings.Set("a"); err == nil { + t.Fatal("'a' should be an invalid exposed port") + } + if err := mappings.Set("1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := mappings.Set("1234:localhost:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } +} diff --git a/cmd/yggstack/types/resolver.go b/cmd/yggstack/types/resolver.go new file mode 100644 index 00000000..d49efe8f --- /dev/null +++ b/cmd/yggstack/types/resolver.go @@ -0,0 +1,48 @@ +package types + +import ( + "context" + "fmt" + "net" + + "github.com/yggdrasil-network/yggdrasil-go/contrib/netstack" +) + +type NameResolver struct { + resolver *net.Resolver +} + +func NewNameResolver(stack *netstack.YggdrasilNetstack, nameserver string) *NameResolver { + res := &NameResolver{ + resolver: &net.Resolver{ + PreferGo: true, + }, + } + if nameserver != "" { + res.resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { // nolint:staticcheck + address = fmt.Sprintf("[%s]:53", nameserver) // nolint:staticcheck + if nameserver == "" { + return nil, fmt.Errorf("no nameserver configured") + } + return stack.DialContext(ctx, network, address) + } + } + return res +} + +func (r *NameResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + ip := net.ParseIP(name) + if ip == nil { + addrs, err := r.resolver.LookupIP(ctx, "ip6", name) + if err != nil { + fmt.Println("failed to lookup", name, "due to error:", err) + return nil, nil, fmt.Errorf("failed to lookup %q: %s", name, err) + } + if len(addrs) == 0 { + fmt.Println("failed to lookup", name, "due to no addresses") + return nil, nil, fmt.Errorf("no addresses for %q", name) + } + return ctx, addrs[0], nil + } + return ctx, ip, nil +} diff --git a/cmd/yggstack/types/tcpproxy.go b/cmd/yggstack/types/tcpproxy.go new file mode 100644 index 00000000..71163b30 --- /dev/null +++ b/cmd/yggstack/types/tcpproxy.go @@ -0,0 +1,43 @@ +package types + +import "net" + +func connToChan(mtu uint64, conn net.Conn) chan []byte { + c := make(chan []byte) + go func() { + for { + b := make([]byte, mtu) + n, err := conn.Read(b[:]) + if err != nil { + c <- nil + return + } + if n > 0 { + c <- b[:n] + } + } + }() + return c +} + +func ProxyTCP(mtu uint64, c1, c2 net.Conn) { + p1, p2 := connToChan(mtu, c1), connToChan(mtu, c2) + defer c1.Close() + defer c2.Close() + for { + select { + case b := <-p1: + if b == nil { + return + } else if _, err := c2.Write(b); err != nil { + return + } + case b := <-p2: + if b == nil { + return + } else if _, err := c1.Write(b); err != nil { + return + } + } + } +} diff --git a/contrib/netstack/netstack.go b/contrib/netstack/netstack.go index 7cf24044..3bd11575 100644 --- a/contrib/netstack/netstack.go +++ b/contrib/netstack/netstack.go @@ -71,7 +71,11 @@ func (s *YggdrasilNetstack) DialContext(ctx context.Context, network, address st case "tcp", "tcp6": return gonet.DialContextTCP(ctx, s.stack, fa, pn) case "udp", "udp6": - return gonet.DialUDP(s.stack, nil, &fa, pn) + conn, err := gonet.DialUDP(s.stack, nil, &fa, pn) + if err != nil { + return nil, fmt.Errorf("gonet.DialUDP: %w", err) + } + return conn, nil default: return nil, fmt.Errorf("not supported") }