mirror of
				https://github.com/yggdrasil-network/yggstack.git
				synced 2025-11-04 08:25:06 +03:00 
			
		
		
		
	Initial commit
Based on previous work of @neilalexander: https://github.com/yggdrasil-network/yggdrasil-go@netstack Signed-off-by: Vasyl Gello <vasek.gello@gmail.com>
This commit is contained in:
		
						commit
						6e427fefec
					
				
					 9 changed files with 944 additions and 0 deletions
				
			
		
							
								
								
									
										68
									
								
								src/types/mapping.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								src/types/mapping.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,68 @@
 | 
			
		|||
package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TCPMapping struct {
 | 
			
		||||
	Listen *net.TCPAddr
 | 
			
		||||
	Mapped *net.TCPAddr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TCPMappings []TCPMapping
 | 
			
		||||
 | 
			
		||||
func (m *TCPMappings) String() string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *TCPMappings) Set(value string) error {
 | 
			
		||||
	tokens := strings.Split(value, ":")
 | 
			
		||||
	if len(tokens) > 2 {
 | 
			
		||||
		tokens = strings.SplitN(value, ":", 2)
 | 
			
		||||
		host, port, err := net.SplitHostPort(tokens[1])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to split host and port: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		tokens = append(tokens[:1], host, port)
 | 
			
		||||
	}
 | 
			
		||||
	listenport, err := strconv.Atoi(tokens[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("listen port is invalid: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if listenport == 0 {
 | 
			
		||||
		return fmt.Errorf("listen port must not be zero")
 | 
			
		||||
	}
 | 
			
		||||
	mapping := TCPMapping{
 | 
			
		||||
		Listen: &net.TCPAddr{
 | 
			
		||||
			Port: listenport,
 | 
			
		||||
		},
 | 
			
		||||
		Mapped: &net.TCPAddr{
 | 
			
		||||
			IP:   net.IPv6loopback,
 | 
			
		||||
			Port: listenport,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	tokens = tokens[1:]
 | 
			
		||||
	if len(tokens) > 0 {
 | 
			
		||||
		mappedaddr := net.ParseIP(tokens[0])
 | 
			
		||||
		if mappedaddr == nil {
 | 
			
		||||
			return fmt.Errorf("invalid mapped address %q", tokens[0])
 | 
			
		||||
		}
 | 
			
		||||
		mapping.Mapped.IP = mappedaddr
 | 
			
		||||
		tokens = tokens[1:]
 | 
			
		||||
	}
 | 
			
		||||
	if len(tokens) > 0 {
 | 
			
		||||
		mappedport, err := strconv.Atoi(tokens[0])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("mapped port is invalid: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		if mappedport == 0 {
 | 
			
		||||
			return fmt.Errorf("mapped port must not be zero")
 | 
			
		||||
		}
 | 
			
		||||
		mapping.Mapped.Port = mappedport
 | 
			
		||||
	}
 | 
			
		||||
	*m = append(*m, mapping)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										28
									
								
								src/types/mapping_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								src/types/mapping_test.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,28 @@
 | 
			
		|||
package types
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestEndpointMappings(t *testing.T) {
 | 
			
		||||
	var mappings TCPMappings
 | 
			
		||||
	if err := mappings.Set("1234"); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("1234:192.168.1.1"); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("1234:192.168.1.1:4321"); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("1234:[2000::1]:4321"); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("a"); err == nil {
 | 
			
		||||
		t.Fatal("'a' should be an invalid exposed port")
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("1234:localhost"); err == nil {
 | 
			
		||||
		t.Fatal("mapped address must be an IP literal")
 | 
			
		||||
	}
 | 
			
		||||
	if err := mappings.Set("1234:localhost:a"); err == nil {
 | 
			
		||||
		t.Fatal("'a' should be an invalid mapped port")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										71
									
								
								src/types/resolver.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								src/types/resolver.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,71 @@
 | 
			
		|||
package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/ed25519"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/yggdrasil-network/yggdrasil-go/src/address"
 | 
			
		||||
	"github.com/yggdrasil-network/yggstack/contrib/netstack"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const NameMappingSuffix = ".pk.ygg"
 | 
			
		||||
 | 
			
		||||
type NameResolver struct {
 | 
			
		||||
	resolver *net.Resolver
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNameResolver(stack *netstack.YggdrasilNetstack, nameserver string) *NameResolver {
 | 
			
		||||
	res := &NameResolver{
 | 
			
		||||
		resolver: &net.Resolver{
 | 
			
		||||
			PreferGo: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	if nameserver != "" {
 | 
			
		||||
		res.resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { // nolint:staticcheck
 | 
			
		||||
			if nameserver == "" {
 | 
			
		||||
				return nil, fmt.Errorf("no nameserver configured")
 | 
			
		||||
			}
 | 
			
		||||
			address, port, found := strings.Cut(nameserver, ":")
 | 
			
		||||
			if !found {
 | 
			
		||||
				port = "53"
 | 
			
		||||
			}
 | 
			
		||||
			address = net.JoinHostPort(nameserver, port)
 | 
			
		||||
			return stack.DialContext(ctx, network, address)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *NameResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
 | 
			
		||||
	if strings.HasSuffix(name, NameMappingSuffix) {
 | 
			
		||||
		name = strings.TrimSuffix(name, NameMappingSuffix)
 | 
			
		||||
		// Check if remaining string contains a dot and
 | 
			
		||||
		// assume publickey is a rightmost token
 | 
			
		||||
		name = name[strings.LastIndex(name, ".")+1:]
 | 
			
		||||
		var pk [ed25519.PublicKeySize]byte
 | 
			
		||||
		if b, err := hex.DecodeString(name); err != nil {
 | 
			
		||||
			return nil, nil, fmt.Errorf("hex.DecodeString: %w", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			copy(pk[:], b)
 | 
			
		||||
			return ctx, net.IP(address.AddrForKey(pk[:])[:]), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	ip := net.ParseIP(name)
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
		addrs, err := r.resolver.LookupIP(ctx, "ip6", name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("failed to lookup", name, "due to error:", err)
 | 
			
		||||
			return nil, nil, fmt.Errorf("failed to lookup %q: %s", name, err)
 | 
			
		||||
		}
 | 
			
		||||
		if len(addrs) == 0 {
 | 
			
		||||
			fmt.Println("failed to lookup", name, "due to no addresses")
 | 
			
		||||
			return nil, nil, fmt.Errorf("no addresses for %q", name)
 | 
			
		||||
		}
 | 
			
		||||
		return ctx, addrs[0], nil
 | 
			
		||||
	}
 | 
			
		||||
	return ctx, ip, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								src/types/tcpproxy.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								src/types/tcpproxy.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,42 @@
 | 
			
		|||
package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func tcpProxyFunc(mtu uint64, dst, src net.Conn) error {
 | 
			
		||||
	buf := make([]byte, mtu)
 | 
			
		||||
	for {
 | 
			
		||||
		n, err := src.Read(buf[:])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if n > 0 {
 | 
			
		||||
			n, err = dst.Write(buf[:n])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ProxyTCP(mtu uint64, c1, c2 net.Conn) error {
 | 
			
		||||
	// Start proxying
 | 
			
		||||
	errCh := make(chan error, 2)
 | 
			
		||||
	go func() { errCh <- tcpProxyFunc(mtu, c1, c2) }()
 | 
			
		||||
	go func() { errCh <- tcpProxyFunc(mtu, c2, c1) }()
 | 
			
		||||
 | 
			
		||||
	// Wait
 | 
			
		||||
	for i := 0; i < 2; i++ {
 | 
			
		||||
		e := <-errCh
 | 
			
		||||
		if e != nil {
 | 
			
		||||
			// Close connections and return
 | 
			
		||||
			c1.Close()
 | 
			
		||||
			c2.Close()
 | 
			
		||||
			return e
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue