diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index d6d0d1a6..95d40151 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -29,6 +29,7 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/defaults" "github.com/yggdrasil-network/yggdrasil-go/src/core" + "github.com/yggdrasil-network/yggdrasil-go/src/ipv6rwc" "github.com/yggdrasil-network/yggdrasil-go/src/multicast" "github.com/yggdrasil-network/yggdrasil-go/src/tuntap" "github.com/yggdrasil-network/yggdrasil-go/src/version" @@ -353,7 +354,8 @@ func run(args yggArgs, ctx context.Context, done chan struct{}) { } n.multicast.SetupAdminHandlers(n.admin) // Start the TUN/TAP interface - if err := n.tuntap.Init(&n.core, cfg, logger, nil); err != nil { + rwc := ipv6rwc.NewReadWriteCloser(&n.core) + if err := n.tuntap.Init(rwc, cfg, logger, nil); err != nil { logger.Errorln("An error occurred initialising TUN/TAP:", err) } else if err := n.tuntap.Start(); err != nil { logger.Errorln("An error occurred starting TUN/TAP:", err) diff --git a/cmd/yggdrasilctl/cmd_line_env.go b/cmd/yggdrasilctl/cmd_line_env.go new file mode 100644 index 00000000..bd6df8fc --- /dev/null +++ b/cmd/yggdrasilctl/cmd_line_env.go @@ -0,0 +1,94 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io/ioutil" + "log" + "os" + + "github.com/hjson/hjson-go" + "golang.org/x/text/encoding/unicode" + + "github.com/yggdrasil-network/yggdrasil-go/src/defaults" +) + +type CmdLineEnv struct { + args []string + endpoint, server string + injson, verbose, ver bool +} + +func newCmdLineEnv() CmdLineEnv { + var cmdLineEnv CmdLineEnv + cmdLineEnv.endpoint = defaults.GetDefaults().DefaultAdminListen + return cmdLineEnv +} + +func (cmdLineEnv *CmdLineEnv) parseFlagsAndArgs() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] command [key=value] [key=value] ...\n\n", os.Args[0]) + fmt.Println("Options:") + flag.PrintDefaults() + fmt.Println() + fmt.Println("Please note that options must always specified BEFORE the command\non the command line or they will be ignored.") + fmt.Println() + fmt.Println("Commands:\n - Use \"list\" for a list of available commands") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" - ", os.Args[0], "list") + fmt.Println(" - ", os.Args[0], "getPeers") + fmt.Println(" - ", os.Args[0], "-v getSelf") + fmt.Println(" - ", os.Args[0], "setTunTap name=auto mtu=1500 tap_mode=false") + fmt.Println(" - ", os.Args[0], "-endpoint=tcp://localhost:9001 getDHT") + fmt.Println(" - ", os.Args[0], "-endpoint=unix:///var/run/ygg.sock getDHT") + } + + server := flag.String("endpoint", cmdLineEnv.endpoint, "Admin socket endpoint") + injson := flag.Bool("json", false, "Output in JSON format (as opposed to pretty-print)") + verbose := flag.Bool("v", false, "Verbose output (includes public keys)") + ver := flag.Bool("version", false, "Prints the version of this build") + + flag.Parse() + + cmdLineEnv.args = flag.Args() + cmdLineEnv.server = *server + cmdLineEnv.injson = *injson + cmdLineEnv.verbose = *verbose + cmdLineEnv.ver = *ver +} + +func (cmdLineEnv *CmdLineEnv) setEndpoint(logger *log.Logger) { + if cmdLineEnv.server == cmdLineEnv.endpoint { + if config, err := ioutil.ReadFile(defaults.GetDefaults().DefaultConfigFile); err == nil { + if bytes.Equal(config[0:2], []byte{0xFF, 0xFE}) || + bytes.Equal(config[0:2], []byte{0xFE, 0xFF}) { + utf := unicode.UTF16(unicode.BigEndian, unicode.UseBOM) + decoder := utf.NewDecoder() + config, err = decoder.Bytes(config) + if err != nil { + panic(err) + } + } + var dat map[string]interface{} + if err := hjson.Unmarshal(config, &dat); err != nil { + panic(err) + } + if ep, ok := dat["AdminListen"].(string); ok && (ep != "none" && ep != "") { + cmdLineEnv.endpoint = ep + logger.Println("Found platform default config file", defaults.GetDefaults().DefaultConfigFile) + logger.Println("Using endpoint", cmdLineEnv.endpoint, "from AdminListen") + } else { + logger.Println("Configuration file doesn't contain appropriate AdminListen option") + logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen) + } + } else { + logger.Println("Can't open config file from default location", defaults.GetDefaults().DefaultConfigFile) + logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen) + } + } else { + cmdLineEnv.endpoint = cmdLineEnv.server + logger.Println("Using endpoint", cmdLineEnv.endpoint, "from command line") + } +} diff --git a/cmd/yggdrasilctl/main.go b/cmd/yggdrasilctl/main.go index 91923392..788b4f19 100644 --- a/cmd/yggdrasilctl/main.go +++ b/cmd/yggdrasilctl/main.go @@ -6,7 +6,6 @@ import ( "errors" "flag" "fmt" - "io/ioutil" "log" "net" "net/url" @@ -15,10 +14,6 @@ import ( "strconv" "strings" - "golang.org/x/text/encoding/unicode" - - "github.com/hjson/hjson-go" - "github.com/yggdrasil-network/yggdrasil-go/src/defaults" "github.com/yggdrasil-network/yggdrasil-go/src/version" ) @@ -32,6 +27,7 @@ func main() { func run() int { logbuffer := &bytes.Buffer{} logger := log.New(logbuffer, "", log.Flags()) + defer func() int { if r := recover(); r != nil { logger.Println("Fatal error:", r) @@ -41,97 +37,24 @@ func run() int { return 0 }() - endpoint := defaults.GetDefaults().DefaultAdminListen + cmdLineEnv := newCmdLineEnv() + cmdLineEnv.parseFlagsAndArgs() - flag.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] command [key=value] [key=value] ...\n\n", os.Args[0]) - fmt.Println("Options:") - flag.PrintDefaults() - fmt.Println() - fmt.Println("Please note that options must always specified BEFORE the command\non the command line or they will be ignored.") - fmt.Println() - fmt.Println("Commands:\n - Use \"list\" for a list of available commands") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" - ", os.Args[0], "list") - fmt.Println(" - ", os.Args[0], "getPeers") - fmt.Println(" - ", os.Args[0], "-v getSelf") - fmt.Println(" - ", os.Args[0], "setTunTap name=auto mtu=1500 tap_mode=false") - fmt.Println(" - ", os.Args[0], "-endpoint=tcp://localhost:9001 getDHT") - fmt.Println(" - ", os.Args[0], "-endpoint=unix:///var/run/ygg.sock getDHT") - } - server := flag.String("endpoint", endpoint, "Admin socket endpoint") - injson := flag.Bool("json", false, "Output in JSON format (as opposed to pretty-print)") - verbose := flag.Bool("v", false, "Verbose output (includes public keys)") - ver := flag.Bool("version", false, "Prints the version of this build") - flag.Parse() - args := flag.Args() - - if *ver { + if cmdLineEnv.ver { fmt.Println("Build name:", version.BuildName()) fmt.Println("Build version:", version.BuildVersion()) fmt.Println("To get the version number of the running Yggdrasil node, run", os.Args[0], "getSelf") return 0 } - if len(args) == 0 { + if len(cmdLineEnv.args) == 0 { flag.Usage() return 0 } - if *server == endpoint { - if config, err := ioutil.ReadFile(defaults.GetDefaults().DefaultConfigFile); err == nil { - if bytes.Equal(config[0:2], []byte{0xFF, 0xFE}) || - bytes.Equal(config[0:2], []byte{0xFE, 0xFF}) { - utf := unicode.UTF16(unicode.BigEndian, unicode.UseBOM) - decoder := utf.NewDecoder() - config, err = decoder.Bytes(config) - if err != nil { - panic(err) - } - } - var dat map[string]interface{} - if err := hjson.Unmarshal(config, &dat); err != nil { - panic(err) - } - if ep, ok := dat["AdminListen"].(string); ok && (ep != "none" && ep != "") { - endpoint = ep - logger.Println("Found platform default config file", defaults.GetDefaults().DefaultConfigFile) - logger.Println("Using endpoint", endpoint, "from AdminListen") - } else { - logger.Println("Configuration file doesn't contain appropriate AdminListen option") - logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen) - } - } else { - logger.Println("Can't open config file from default location", defaults.GetDefaults().DefaultConfigFile) - logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen) - } - } else { - endpoint = *server - logger.Println("Using endpoint", endpoint, "from command line") - } + cmdLineEnv.setEndpoint(logger) - var conn net.Conn - u, err := url.Parse(endpoint) - if err == nil { - switch strings.ToLower(u.Scheme) { - case "unix": - logger.Println("Connecting to UNIX socket", endpoint[7:]) - conn, err = net.Dial("unix", endpoint[7:]) - case "tcp": - logger.Println("Connecting to TCP socket", u.Host) - conn, err = net.Dial("tcp", u.Host) - default: - logger.Println("Unknown protocol or malformed address - check your endpoint") - err = errors.New("protocol not supported") - } - } else { - logger.Println("Connecting to TCP socket", u.Host) - conn, err = net.Dial("tcp", endpoint) - } - if err != nil { - panic(err) - } + conn := connect(cmdLineEnv.endpoint, logger) logger.Println("Connected") defer conn.Close() @@ -140,7 +63,7 @@ func run() int { send := make(admin_info) recv := make(admin_info) - for c, a := range args { + for c, a := range cmdLineEnv.args { if c == 0 { if strings.HasPrefix(a, "-") { logger.Printf("Ignoring flag %s as it should be specified before other parameters\n", a) @@ -176,7 +99,9 @@ func run() int { if err := encoder.Encode(&send); err != nil { panic(err) } + logger.Printf("Request sent") + if err := decoder.Decode(&recv); err == nil { logger.Printf("Response received") if recv["status"] == "error" { @@ -195,252 +120,16 @@ func run() int { fmt.Println("Missing response body (malformed response?)") return 1 } - req := recv["request"].(map[string]interface{}) res := recv["response"].(map[string]interface{}) - if *injson { + if cmdLineEnv.injson { if json, err := json.MarshalIndent(res, "", " "); err == nil { fmt.Println(string(json)) } return 0 } - switch strings.ToLower(req["request"].(string)) { - case "dot": - fmt.Println(res["dot"]) - case "list", "getpeers", "getswitchpeers", "getdht", "getsessions", "dhtping": - maxWidths := make(map[string]int) - var keyOrder []string - keysOrdered := false - - for _, tlv := range res { - for slk, slv := range tlv.(map[string]interface{}) { - if !keysOrdered { - for k := range slv.(map[string]interface{}) { - if !*verbose { - if k == "box_pub_key" || k == "box_sig_key" || k == "nodeinfo" || k == "was_mtu_fixed" { - continue - } - } - keyOrder = append(keyOrder, fmt.Sprint(k)) - } - sort.Strings(keyOrder) - keysOrdered = true - } - for k, v := range slv.(map[string]interface{}) { - if len(fmt.Sprint(slk)) > maxWidths["key"] { - maxWidths["key"] = len(fmt.Sprint(slk)) - } - if len(fmt.Sprint(v)) > maxWidths[k] { - maxWidths[k] = len(fmt.Sprint(v)) - if maxWidths[k] < len(k) { - maxWidths[k] = len(k) - } - } - } - } - - if len(keyOrder) > 0 { - fmt.Printf("%-"+fmt.Sprint(maxWidths["key"])+"s ", "") - for _, v := range keyOrder { - fmt.Printf("%-"+fmt.Sprint(maxWidths[v])+"s ", v) - } - fmt.Println() - } - - for slk, slv := range tlv.(map[string]interface{}) { - fmt.Printf("%-"+fmt.Sprint(maxWidths["key"])+"s ", slk) - for _, k := range keyOrder { - preformatted := slv.(map[string]interface{})[k] - var formatted string - switch k { - case "bytes_sent", "bytes_recvd": - formatted = fmt.Sprintf("%d", uint(preformatted.(float64))) - case "uptime", "last_seen": - seconds := uint(preformatted.(float64)) % 60 - minutes := uint(preformatted.(float64)/60) % 60 - hours := uint(preformatted.(float64) / 60 / 60) - formatted = fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) - default: - formatted = fmt.Sprint(preformatted) - } - fmt.Printf("%-"+fmt.Sprint(maxWidths[k])+"s ", formatted) - } - fmt.Println() - } - } - case "gettuntap", "settuntap": - for k, v := range res { - fmt.Println("Interface name:", k) - if mtu, ok := v.(map[string]interface{})["mtu"].(float64); ok { - fmt.Println("Interface MTU:", mtu) - } - if tap_mode, ok := v.(map[string]interface{})["tap_mode"].(bool); ok { - fmt.Println("TAP mode:", tap_mode) - } - } - case "getself": - for k, v := range res["self"].(map[string]interface{}) { - if buildname, ok := v.(map[string]interface{})["build_name"].(string); ok && buildname != "unknown" { - fmt.Println("Build name:", buildname) - } - if buildversion, ok := v.(map[string]interface{})["build_version"].(string); ok && buildversion != "unknown" { - fmt.Println("Build version:", buildversion) - } - fmt.Println("IPv6 address:", k) - if subnet, ok := v.(map[string]interface{})["subnet"].(string); ok { - fmt.Println("IPv6 subnet:", subnet) - } - if boxSigKey, ok := v.(map[string]interface{})["key"].(string); ok { - fmt.Println("Public key:", boxSigKey) - } - if coords, ok := v.(map[string]interface{})["coords"].(string); ok { - fmt.Println("Coords:", coords) - } - if *verbose { - if nodeID, ok := v.(map[string]interface{})["node_id"].(string); ok { - fmt.Println("Node ID:", nodeID) - } - if boxPubKey, ok := v.(map[string]interface{})["box_pub_key"].(string); ok { - fmt.Println("Public encryption key:", boxPubKey) - } - if boxSigKey, ok := v.(map[string]interface{})["box_sig_key"].(string); ok { - fmt.Println("Public signing key:", boxSigKey) - } - } - } - case "getswitchqueues": - maximumqueuesize := float64(4194304) - portqueues := make(map[float64]float64) - portqueuesize := make(map[float64]float64) - portqueuepackets := make(map[float64]float64) - v := res["switchqueues"].(map[string]interface{}) - if queuecount, ok := v["queues_count"].(float64); ok { - fmt.Printf("Active queue count: %d queues\n", uint(queuecount)) - } - if queuesize, ok := v["queues_size"].(float64); ok { - fmt.Printf("Active queue size: %d bytes\n", uint(queuesize)) - } - if highestqueuecount, ok := v["highest_queues_count"].(float64); ok { - fmt.Printf("Highest queue count: %d queues\n", uint(highestqueuecount)) - } - if highestqueuesize, ok := v["highest_queues_size"].(float64); ok { - fmt.Printf("Highest queue size: %d bytes\n", uint(highestqueuesize)) - } - if m, ok := v["maximum_queues_size"].(float64); ok { - maximumqueuesize = m - fmt.Printf("Maximum queue size: %d bytes\n", uint(maximumqueuesize)) - } - if queues, ok := v["queues"].([]interface{}); ok { - if len(queues) != 0 { - fmt.Println("Active queues:") - for _, v := range queues { - queueport := v.(map[string]interface{})["queue_port"].(float64) - queuesize := v.(map[string]interface{})["queue_size"].(float64) - queuepackets := v.(map[string]interface{})["queue_packets"].(float64) - queueid := v.(map[string]interface{})["queue_id"].(string) - portqueues[queueport]++ - portqueuesize[queueport] += queuesize - portqueuepackets[queueport] += queuepackets - queuesizepercent := (100 / maximumqueuesize) * queuesize - fmt.Printf("- Switch port %d, Stream ID: %v, size: %d bytes (%d%% full), %d packets\n", - uint(queueport), []byte(queueid), uint(queuesize), - uint(queuesizepercent), uint(queuepackets)) - } - } - } - if len(portqueuesize) > 0 && len(portqueuepackets) > 0 { - fmt.Println("Aggregated statistics by switchport:") - for k, v := range portqueuesize { - queuesizepercent := (100 / (portqueues[k] * maximumqueuesize)) * v - fmt.Printf("- Switch port %d, size: %d bytes (%d%% full), %d packets\n", - uint(k), uint(v), uint(queuesizepercent), uint(portqueuepackets[k])) - } - } - case "addpeer", "removepeer", "addallowedencryptionpublickey", "removeallowedencryptionpublickey", "addsourcesubnet", "addroute", "removesourcesubnet", "removeroute": - if _, ok := res["added"]; ok { - for _, v := range res["added"].([]interface{}) { - fmt.Println("Added:", fmt.Sprint(v)) - } - } - if _, ok := res["not_added"]; ok { - for _, v := range res["not_added"].([]interface{}) { - fmt.Println("Not added:", fmt.Sprint(v)) - } - } - if _, ok := res["removed"]; ok { - for _, v := range res["removed"].([]interface{}) { - fmt.Println("Removed:", fmt.Sprint(v)) - } - } - if _, ok := res["not_removed"]; ok { - for _, v := range res["not_removed"].([]interface{}) { - fmt.Println("Not removed:", fmt.Sprint(v)) - } - } - case "getallowedencryptionpublickeys": - if _, ok := res["allowed_box_pubs"]; !ok { - fmt.Println("All connections are allowed") - } else if res["allowed_box_pubs"] == nil { - fmt.Println("All connections are allowed") - } else { - fmt.Println("Connections are allowed only from the following public box keys:") - for _, v := range res["allowed_box_pubs"].([]interface{}) { - fmt.Println("-", v) - } - } - case "getmulticastinterfaces": - if _, ok := res["multicast_interfaces"]; !ok { - fmt.Println("No multicast interfaces found") - } else if res["multicast_interfaces"] == nil { - fmt.Println("No multicast interfaces found") - } else { - fmt.Println("Multicast peer discovery is active on:") - for _, v := range res["multicast_interfaces"].([]interface{}) { - fmt.Println("-", v) - } - } - case "getsourcesubnets": - if _, ok := res["source_subnets"]; !ok { - fmt.Println("No source subnets found") - } else if res["source_subnets"] == nil { - fmt.Println("No source subnets found") - } else { - fmt.Println("Source subnets:") - for _, v := range res["source_subnets"].([]interface{}) { - fmt.Println("-", v) - } - } - case "getroutes": - if routes, ok := res["routes"].(map[string]interface{}); !ok { - fmt.Println("No routes found") - } else { - if res["routes"] == nil || len(routes) == 0 { - fmt.Println("No routes found") - } else { - fmt.Println("Routes:") - for k, v := range routes { - if pv, ok := v.(string); ok { - fmt.Println("-", k, " via ", pv) - } - } - } - } - case "settunnelrouting": - fallthrough - case "gettunnelrouting": - if enabled, ok := res["enabled"].(bool); !ok { - fmt.Println("Tunnel routing is disabled") - } else if !enabled { - fmt.Println("Tunnel routing is disabled") - } else { - fmt.Println("Tunnel routing is enabled") - } - default: - if json, err := json.MarshalIndent(recv["response"], "", " "); err == nil { - fmt.Println(string(json)) - } - } + handleAll(recv, cmdLineEnv.verbose) } else { logger.Println("Error receiving response:", err) } @@ -448,5 +137,321 @@ func run() int { if v, ok := recv["status"]; ok && v != "success" { return 1 } + return 0 } + +func connect(endpoint string, logger *log.Logger) net.Conn { + var conn net.Conn + + u, err := url.Parse(endpoint) + + if err == nil { + switch strings.ToLower(u.Scheme) { + case "unix": + logger.Println("Connecting to UNIX socket", endpoint[7:]) + conn, err = net.Dial("unix", endpoint[7:]) + case "tcp": + logger.Println("Connecting to TCP socket", u.Host) + conn, err = net.Dial("tcp", u.Host) + default: + logger.Println("Unknown protocol or malformed address - check your endpoint") + err = errors.New("protocol not supported") + } + } else { + logger.Println("Connecting to TCP socket", u.Host) + conn, err = net.Dial("tcp", endpoint) + } + + if err != nil { + panic(err) + } + + return conn +} + +func handleAll(recv map[string]interface{}, verbose bool) { + req := recv["request"].(map[string]interface{}) + res := recv["response"].(map[string]interface{}) + + switch strings.ToLower(req["request"].(string)) { + case "dot": + handleDot(res) + case "list", "getpeers", "getswitchpeers", "getdht", "getsessions", "dhtping": + handleVariousInfo(res, verbose) + case "gettuntap", "settuntap": + handleGetAndSetTunTap(res) + case "getself": + handleGetSelf(res, verbose) + case "getswitchqueues": + handleGetSwitchQueues(res) + case "addpeer", "removepeer", "addallowedencryptionpublickey", "removeallowedencryptionpublickey", "addsourcesubnet", "addroute", "removesourcesubnet", "removeroute": + handleAddsAndRemoves(res) + case "getallowedencryptionpublickeys": + handleGetAllowedEncryptionPublicKeys(res) + case "getmulticastinterfaces": + handleGetMulticastInterfaces(res) + case "getsourcesubnets": + handleGetSourceSubnets(res) + case "getroutes": + handleGetRoutes(res) + case "settunnelrouting": + fallthrough + case "gettunnelrouting": + handleGetTunnelRouting(res) + default: + if json, err := json.MarshalIndent(recv["response"], "", " "); err == nil { + fmt.Println(string(json)) + } + } +} + +func handleDot(res map[string]interface{}) { + fmt.Println(res["dot"]) +} + +func handleVariousInfo(res map[string]interface{}, verbose bool) { + maxWidths := make(map[string]int) + var keyOrder []string + keysOrdered := false + + for _, tlv := range res { + for slk, slv := range tlv.(map[string]interface{}) { + if !keysOrdered { + for k := range slv.(map[string]interface{}) { + if !verbose { + if k == "box_pub_key" || k == "box_sig_key" || k == "nodeinfo" || k == "was_mtu_fixed" { + continue + } + } + keyOrder = append(keyOrder, fmt.Sprint(k)) + } + sort.Strings(keyOrder) + keysOrdered = true + } + for k, v := range slv.(map[string]interface{}) { + if len(fmt.Sprint(slk)) > maxWidths["key"] { + maxWidths["key"] = len(fmt.Sprint(slk)) + } + if len(fmt.Sprint(v)) > maxWidths[k] { + maxWidths[k] = len(fmt.Sprint(v)) + if maxWidths[k] < len(k) { + maxWidths[k] = len(k) + } + } + } + } + + if len(keyOrder) > 0 { + fmt.Printf("%-"+fmt.Sprint(maxWidths["key"])+"s ", "") + for _, v := range keyOrder { + fmt.Printf("%-"+fmt.Sprint(maxWidths[v])+"s ", v) + } + fmt.Println() + } + + for slk, slv := range tlv.(map[string]interface{}) { + fmt.Printf("%-"+fmt.Sprint(maxWidths["key"])+"s ", slk) + for _, k := range keyOrder { + preformatted := slv.(map[string]interface{})[k] + var formatted string + switch k { + case "bytes_sent", "bytes_recvd": + formatted = fmt.Sprintf("%d", uint(preformatted.(float64))) + case "uptime", "last_seen": + seconds := uint(preformatted.(float64)) % 60 + minutes := uint(preformatted.(float64)/60) % 60 + hours := uint(preformatted.(float64) / 60 / 60) + formatted = fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + default: + formatted = fmt.Sprint(preformatted) + } + fmt.Printf("%-"+fmt.Sprint(maxWidths[k])+"s ", formatted) + } + fmt.Println() + } + } +} + +func handleGetAndSetTunTap(res map[string]interface{}) { + for k, v := range res { + fmt.Println("Interface name:", k) + if mtu, ok := v.(map[string]interface{})["mtu"].(float64); ok { + fmt.Println("Interface MTU:", mtu) + } + if tap_mode, ok := v.(map[string]interface{})["tap_mode"].(bool); ok { + fmt.Println("TAP mode:", tap_mode) + } + } +} + +func handleGetSelf(res map[string]interface{}, verbose bool) { + for k, v := range res["self"].(map[string]interface{}) { + if buildname, ok := v.(map[string]interface{})["build_name"].(string); ok && buildname != "unknown" { + fmt.Println("Build name:", buildname) + } + if buildversion, ok := v.(map[string]interface{})["build_version"].(string); ok && buildversion != "unknown" { + fmt.Println("Build version:", buildversion) + } + fmt.Println("IPv6 address:", k) + if subnet, ok := v.(map[string]interface{})["subnet"].(string); ok { + fmt.Println("IPv6 subnet:", subnet) + } + if boxSigKey, ok := v.(map[string]interface{})["key"].(string); ok { + fmt.Println("Public key:", boxSigKey) + } + if coords, ok := v.(map[string]interface{})["coords"].(string); ok { + fmt.Println("Coords:", coords) + } + if verbose { + if nodeID, ok := v.(map[string]interface{})["node_id"].(string); ok { + fmt.Println("Node ID:", nodeID) + } + if boxPubKey, ok := v.(map[string]interface{})["box_pub_key"].(string); ok { + fmt.Println("Public encryption key:", boxPubKey) + } + if boxSigKey, ok := v.(map[string]interface{})["box_sig_key"].(string); ok { + fmt.Println("Public signing key:", boxSigKey) + } + } + } +} + +func handleGetSwitchQueues(res map[string]interface{}) { + maximumqueuesize := float64(4194304) + portqueues := make(map[float64]float64) + portqueuesize := make(map[float64]float64) + portqueuepackets := make(map[float64]float64) + v := res["switchqueues"].(map[string]interface{}) + if queuecount, ok := v["queues_count"].(float64); ok { + fmt.Printf("Active queue count: %d queues\n", uint(queuecount)) + } + if queuesize, ok := v["queues_size"].(float64); ok { + fmt.Printf("Active queue size: %d bytes\n", uint(queuesize)) + } + if highestqueuecount, ok := v["highest_queues_count"].(float64); ok { + fmt.Printf("Highest queue count: %d queues\n", uint(highestqueuecount)) + } + if highestqueuesize, ok := v["highest_queues_size"].(float64); ok { + fmt.Printf("Highest queue size: %d bytes\n", uint(highestqueuesize)) + } + if m, ok := v["maximum_queues_size"].(float64); ok { + maximumqueuesize = m + fmt.Printf("Maximum queue size: %d bytes\n", uint(maximumqueuesize)) + } + if queues, ok := v["queues"].([]interface{}); ok { + if len(queues) != 0 { + fmt.Println("Active queues:") + for _, v := range queues { + queueport := v.(map[string]interface{})["queue_port"].(float64) + queuesize := v.(map[string]interface{})["queue_size"].(float64) + queuepackets := v.(map[string]interface{})["queue_packets"].(float64) + queueid := v.(map[string]interface{})["queue_id"].(string) + portqueues[queueport]++ + portqueuesize[queueport] += queuesize + portqueuepackets[queueport] += queuepackets + queuesizepercent := (100 / maximumqueuesize) * queuesize + fmt.Printf("- Switch port %d, Stream ID: %v, size: %d bytes (%d%% full), %d packets\n", + uint(queueport), []byte(queueid), uint(queuesize), + uint(queuesizepercent), uint(queuepackets)) + } + } + } + if len(portqueuesize) > 0 && len(portqueuepackets) > 0 { + fmt.Println("Aggregated statistics by switchport:") + for k, v := range portqueuesize { + queuesizepercent := (100 / (portqueues[k] * maximumqueuesize)) * v + fmt.Printf("- Switch port %d, size: %d bytes (%d%% full), %d packets\n", + uint(k), uint(v), uint(queuesizepercent), uint(portqueuepackets[k])) + } + } +} + +func handleAddsAndRemoves(res map[string]interface{}) { + if _, ok := res["added"]; ok { + for _, v := range res["added"].([]interface{}) { + fmt.Println("Added:", fmt.Sprint(v)) + } + } + if _, ok := res["not_added"]; ok { + for _, v := range res["not_added"].([]interface{}) { + fmt.Println("Not added:", fmt.Sprint(v)) + } + } + if _, ok := res["removed"]; ok { + for _, v := range res["removed"].([]interface{}) { + fmt.Println("Removed:", fmt.Sprint(v)) + } + } + if _, ok := res["not_removed"]; ok { + for _, v := range res["not_removed"].([]interface{}) { + fmt.Println("Not removed:", fmt.Sprint(v)) + } + } +} + +func handleGetAllowedEncryptionPublicKeys(res map[string]interface{}) { + if _, ok := res["allowed_box_pubs"]; !ok { + fmt.Println("All connections are allowed") + } else if res["allowed_box_pubs"] == nil { + fmt.Println("All connections are allowed") + } else { + fmt.Println("Connections are allowed only from the following public box keys:") + for _, v := range res["allowed_box_pubs"].([]interface{}) { + fmt.Println("-", v) + } + } +} + +func handleGetMulticastInterfaces(res map[string]interface{}) { + if _, ok := res["multicast_interfaces"]; !ok { + fmt.Println("No multicast interfaces found") + } else if res["multicast_interfaces"] == nil { + fmt.Println("No multicast interfaces found") + } else { + fmt.Println("Multicast peer discovery is active on:") + for _, v := range res["multicast_interfaces"].([]interface{}) { + fmt.Println("-", v) + } + } +} + +func handleGetSourceSubnets(res map[string]interface{}) { + if _, ok := res["source_subnets"]; !ok { + fmt.Println("No source subnets found") + } else if res["source_subnets"] == nil { + fmt.Println("No source subnets found") + } else { + fmt.Println("Source subnets:") + for _, v := range res["source_subnets"].([]interface{}) { + fmt.Println("-", v) + } + } +} + +func handleGetRoutes(res map[string]interface{}) { + if routes, ok := res["routes"].(map[string]interface{}); !ok { + fmt.Println("No routes found") + } else { + if res["routes"] == nil || len(routes) == 0 { + fmt.Println("No routes found") + } else { + fmt.Println("Routes:") + for k, v := range routes { + if pv, ok := v.(string); ok { + fmt.Println("-", k, " via ", pv) + } + } + } + } +} + +func handleGetTunnelRouting(res map[string]interface{}) { + if enabled, ok := res["enabled"].(bool); !ok { + fmt.Println("Tunnel routing is disabled") + } else if !enabled { + fmt.Println("Tunnel routing is disabled") + } else { + fmt.Println("Tunnel routing is enabled") + } +} diff --git a/contrib/systemd/yggdrasil.service b/contrib/systemd/yggdrasil.service index 3002e61b..27d27907 100644 --- a/contrib/systemd/yggdrasil.service +++ b/contrib/systemd/yggdrasil.service @@ -10,7 +10,7 @@ Group=yggdrasil ProtectHome=true ProtectSystem=true SyslogIdentifier=yggdrasil -CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW +CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE ExecStartPre=+-/sbin/modprobe tun ExecStart=/usr/bin/yggdrasil -useconffile /etc/yggdrasil.conf ExecReload=/bin/kill -HUP $MAINPID diff --git a/src/address/address.go b/src/address/address.go index 7add23ac..0e2400ed 100644 --- a/src/address/address.go +++ b/src/address/address.go @@ -64,7 +64,7 @@ func AddrForKey(publicKey ed25519.PublicKey) *Address { buf[idx] = ^buf[idx] } var addr Address - var temp []byte + var temp = make([]byte, 0, 32) done := false ones := byte(0) bits := byte(0) diff --git a/src/core/api.go b/src/core/api.go index 05d9f36f..c312923d 100644 --- a/src/core/api.go +++ b/src/core/api.go @@ -48,7 +48,7 @@ type Session struct { func (c *Core) GetSelf() Self { var self Self - s := c.pc.PacketConn.Debug.GetSelf() + s := c.PacketConn.PacketConn.Debug.GetSelf() self.Key = s.Key self.Root = s.Root self.Coords = s.Coords @@ -63,7 +63,7 @@ func (c *Core) GetPeers() []Peer { names[info.conn] = info.lname } c.links.mutex.Unlock() - ps := c.pc.PacketConn.Debug.GetPeers() + ps := c.PacketConn.PacketConn.Debug.GetPeers() for _, p := range ps { var info Peer info.Key = p.Key @@ -81,7 +81,7 @@ func (c *Core) GetPeers() []Peer { func (c *Core) GetDHT() []DHTEntry { var dhts []DHTEntry - ds := c.pc.PacketConn.Debug.GetDHT() + ds := c.PacketConn.PacketConn.Debug.GetDHT() for _, d := range ds { var info DHTEntry info.Key = d.Key @@ -94,7 +94,7 @@ func (c *Core) GetDHT() []DHTEntry { func (c *Core) GetPaths() []PathEntry { var paths []PathEntry - ps := c.pc.PacketConn.Debug.GetPaths() + ps := c.PacketConn.PacketConn.Debug.GetPaths() for _, p := range ps { var info PathEntry info.Key = p.Key @@ -106,7 +106,7 @@ func (c *Core) GetPaths() []PathEntry { func (c *Core) GetSessions() []Session { var sessions []Session - ss := c.pc.Debug.GetSessions() + ss := c.PacketConn.Debug.GetSessions() for _, s := range ss { var info Session info.Key = s.Key @@ -239,43 +239,6 @@ func (c *Core) PublicKey() ed25519.PublicKey { return c.public } -func (c *Core) MaxMTU() uint64 { - return c.store.maxSessionMTU() -} - -func (c *Core) SetMTU(mtu uint64) { - if mtu < 1280 { - mtu = 1280 - } - c.store.mutex.Lock() - c.store.mtu = mtu - c.store.mutex.Unlock() -} - -func (c *Core) MTU() uint64 { - c.store.mutex.Lock() - mtu := c.store.mtu - c.store.mutex.Unlock() - return mtu -} - -// Implement io.ReadWriteCloser - -func (c *Core) Read(p []byte) (n int, err error) { - n, err = c.store.readPC(p) - return -} - -func (c *Core) Write(p []byte) (n int, err error) { - n, err = c.store.writePC(p) - return -} - -func (c *Core) Close() error { - c.Stop() - return nil -} - // Hack to get the admin stuff working, TODO something cleaner type AddHandler interface { diff --git a/src/core/core.go b/src/core/core.go index 89d49177..0332980b 100644 --- a/src/core/core.go +++ b/src/core/core.go @@ -7,10 +7,12 @@ import ( "errors" "fmt" "io/ioutil" + "net" "net/url" "time" - iw "github.com/Arceliar/ironwood/encrypted" + iwe "github.com/Arceliar/ironwood/encrypted" + iwt "github.com/Arceliar/ironwood/types" "github.com/Arceliar/phony" "github.com/gologme/log" @@ -26,13 +28,12 @@ type Core struct { // We're going to keep our own copy of the provided config - that way we can // guarantee that it will be covered by the mutex phony.Inbox - pc *iw.PacketConn + *iwe.PacketConn config *config.NodeConfig // Config secret ed25519.PrivateKey public ed25519.PublicKey links links proto protoHandler - store keyStore log *log.Logger addPeerTimer *time.Timer ctx context.Context @@ -62,9 +63,8 @@ func (c *Core) _init() error { c.public = c.secret.Public().(ed25519.PublicKey) // TODO check public against current.PublicKey, error if they don't match - c.pc, err = iw.NewPacketConn(c.secret) + c.PacketConn, err = iwe.NewPacketConn(c.secret) c.ctx, c.ctxCancel = context.WithCancel(context.Background()) - c.store.init(c) c.proto.init(c) if err := c.proto.nodeinfo.setNodeInfo(c.config.NodeInfo, c.config.NodeInfoPrivacy); err != nil { return fmt.Errorf("setNodeInfo: %w", err) @@ -161,23 +161,79 @@ func (c *Core) _start(nc *config.NodeConfig, log *log.Logger) error { // Stop shuts down the Yggdrasil node. func (c *Core) Stop() { - phony.Block(c, c._stop) + phony.Block(c, func() { + c.log.Infoln("Stopping...") + c._close() + c.log.Infoln("Stopped") + }) +} + +func (c *Core) Close() error { + var err error + phony.Block(c, func() { + err = c._close() + }) + return err } // This function is unsafe and should only be ran by the core actor. -func (c *Core) _stop() { - c.log.Infoln("Stopping...") +func (c *Core) _close() error { c.ctxCancel() - c.pc.Close() + err := c.PacketConn.Close() if c.addPeerTimer != nil { c.addPeerTimer.Stop() c.addPeerTimer = nil } _ = c.links.stop() - /* FIXME this deadlocks, need a waitgroup or something to coordinate shutdown - for _, peer := range c.GetPeers() { - c.DisconnectPeer(peer.Port) - } - */ - c.log.Infoln("Stopped") + return err +} + +func (c *Core) MTU() uint64 { + const sessionTypeOverhead = 1 + return c.PacketConn.MTU() - sessionTypeOverhead +} + +func (c *Core) ReadFrom(p []byte) (n int, from net.Addr, err error) { + buf := make([]byte, c.PacketConn.MTU(), 65535) + for { + bs := buf + n, from, err = c.PacketConn.ReadFrom(bs) + if err != nil { + return 0, from, err + } + if n == 0 { + continue + } + switch bs[0] { + case typeSessionTraffic: + // This is what we want to handle here + case typeSessionProto: + var key keyArray + copy(key[:], from.(iwt.Addr)) + data := append([]byte(nil), bs[1:n]...) + c.proto.handleProto(nil, key, data) + continue + default: + continue + } + bs = bs[1:n] + copy(p, bs) + if len(p) < len(bs) { + n = len(p) + } else { + n = len(bs) + } + return + } +} + +func (c *Core) WriteTo(p []byte, addr net.Addr) (n int, err error) { + buf := make([]byte, 0, 65535) + buf = append(buf, typeSessionTraffic) + buf = append(buf, p...) + n, err = c.PacketConn.WriteTo(buf, addr) + if n > 0 { + n -= 1 + } + return } diff --git a/src/core/core_test.go b/src/core/core_test.go index dd60af21..fcfe2e31 100644 --- a/src/core/core_test.go +++ b/src/core/core_test.go @@ -44,13 +44,11 @@ func CreateAndConnectTwo(t testing.TB, verbose bool) (nodeA *Core, nodeB *Core) if err := nodeA.Start(GenerateConfig(), GetLoggerWithPrefix("A: ", verbose)); err != nil { t.Fatal(err) } - nodeA.SetMTU(1500) nodeB = new(Core) if err := nodeB.Start(GenerateConfig(), GetLoggerWithPrefix("B: ", verbose)); err != nil { t.Fatal(err) } - nodeB.SetMTU(1500) u, err := url.Parse("tcp://" + nodeA.links.tcp.getAddr().String()) if err != nil { @@ -94,7 +92,7 @@ func CreateEchoListener(t testing.TB, nodeA *Core, bufLen int, repeats int) chan buf := make([]byte, bufLen) res := make([]byte, bufLen) for i := 0; i < repeats; i++ { - n, err := nodeA.Read(buf) + n, from, err := nodeA.ReadFrom(buf) if err != nil { t.Error(err) return @@ -106,7 +104,7 @@ func CreateEchoListener(t testing.TB, nodeA *Core, bufLen int, repeats int) chan copy(res, buf) copy(res[8:24], buf[24:40]) copy(res[24:40], buf[8:24]) - _, err = nodeA.Write(res) + _, err = nodeA.WriteTo(res, from) if err != nil { t.Error(err) } @@ -141,12 +139,12 @@ func TestCore_Start_Transfer(t *testing.T) { msg[0] = 0x60 copy(msg[8:24], nodeB.Address()) copy(msg[24:40], nodeA.Address()) - _, err := nodeB.Write(msg) + _, err := nodeB.WriteTo(msg, nodeA.LocalAddr()) if err != nil { t.Fatal(err) } buf := make([]byte, msgLen) - _, err = nodeB.Read(buf) + _, _, err = nodeB.ReadFrom(buf) if err != nil { t.Fatal(err) } @@ -179,12 +177,13 @@ func BenchmarkCore_Start_Transfer(b *testing.B) { b.SetBytes(int64(msgLen)) b.ResetTimer() + addr := nodeA.LocalAddr() for i := 0; i < b.N; i++ { - _, err := nodeB.Write(msg) + _, err := nodeB.WriteTo(msg, addr) if err != nil { b.Fatal(err) } - _, err = nodeB.Read(buf) + _, _, err = nodeB.ReadFrom(buf) if err != nil { b.Fatal(err) } diff --git a/src/core/link.go b/src/core/link.go index 165b18b2..ccab9219 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -230,7 +230,7 @@ func (intf *link) handler() (chan struct{}, error) { intf.links.core.log.Infof("Connected %s: %s, source %s", strings.ToUpper(intf.info.linkType), themString, intf.info.local) // Run the handler - err = intf.links.core.pc.HandleConn(ed25519.PublicKey(intf.info.key[:]), intf.conn) + err = intf.links.core.HandleConn(ed25519.PublicKey(intf.info.key[:]), intf.conn) // TODO don't report an error if it's just a 'use of closed network connection' if err != nil { intf.links.core.log.Infof("Disconnected %s: %s, source %s; error: %s", diff --git a/src/core/nodeinfo.go b/src/core/nodeinfo.go index 30644710..4ca21d73 100644 --- a/src/core/nodeinfo.go +++ b/src/core/nodeinfo.go @@ -129,7 +129,7 @@ func (m *nodeinfo) _sendReq(key keyArray, callback func(nodeinfo NodeInfoPayload if callback != nil { m._addCallback(key, callback) } - _, _ = m.proto.core.pc.WriteTo([]byte{typeSessionProto, typeProtoNodeInfoRequest}, iwt.Addr(key[:])) + _, _ = m.proto.core.PacketConn.WriteTo([]byte{typeSessionProto, typeProtoNodeInfoRequest}, iwt.Addr(key[:])) } func (m *nodeinfo) handleReq(from phony.Actor, key keyArray) { @@ -146,7 +146,7 @@ func (m *nodeinfo) handleRes(from phony.Actor, key keyArray, info NodeInfoPayloa func (m *nodeinfo) _sendRes(key keyArray) { bs := append([]byte{typeSessionProto, typeProtoNodeInfoResponse}, m._getNodeInfo()...) - _, _ = m.proto.core.pc.WriteTo(bs, iwt.Addr(key[:])) + _, _ = m.proto.core.PacketConn.WriteTo(bs, iwt.Addr(key[:])) } // Admin socket stuff diff --git a/src/core/proto.go b/src/core/proto.go index 557ac1d5..e60caeff 100644 --- a/src/core/proto.go +++ b/src/core/proto.go @@ -1,6 +1,7 @@ package core import ( + "crypto/ed25519" "encoding/hex" "encoding/json" "errors" @@ -29,6 +30,8 @@ type reqInfo struct { timer *time.Timer // time.AfterFunc cleanup } +type keyArray [ed25519.PublicKeySize]byte + type protoHandler struct { phony.Inbox core *Core @@ -149,7 +152,7 @@ func (p *protoHandler) _handleGetPeersRequest(key keyArray) { for _, pinfo := range peers { tmp := append(bs, pinfo.Key[:]...) const responseOverhead = 2 // 1 debug type, 1 getpeers type - if uint64(len(tmp))+responseOverhead > p.core.store.maxSessionMTU() { + if uint64(len(tmp))+responseOverhead > p.core.MTU() { break } bs = tmp @@ -191,7 +194,7 @@ func (p *protoHandler) _handleGetDHTRequest(key keyArray) { for _, dinfo := range dinfos { tmp := append(bs, dinfo.Key[:]...) const responseOverhead = 2 // 1 debug type, 1 getdht type - if uint64(len(tmp))+responseOverhead > p.core.store.maxSessionMTU() { + if uint64(len(tmp))+responseOverhead > p.core.MTU() { break } bs = tmp @@ -209,7 +212,7 @@ func (p *protoHandler) _handleGetDHTResponse(key keyArray, bs []byte) { func (p *protoHandler) _sendDebug(key keyArray, dType uint8, data []byte) { bs := append([]byte{typeSessionProto, typeProtoDebug, dType}, data...) - _, _ = p.core.pc.WriteTo(bs, iwt.Addr(key[:])) + _, _ = p.core.PacketConn.WriteTo(bs, iwt.Addr(key[:])) } // Admin socket stuff diff --git a/src/core/types.go b/src/core/types.go index e325b55e..258563a1 100644 --- a/src/core/types.go +++ b/src/core/types.go @@ -1,12 +1,5 @@ package core -// Out-of-band packet types -const ( - typeKeyDummy = iota // nolint:deadcode,varcheck - typeKeyLookup - typeKeyResponse -) - // In-band packet types const ( typeSessionDummy = iota // nolint:deadcode,varcheck diff --git a/src/core/icmpv6.go b/src/ipv6rwc/icmpv6.go similarity index 99% rename from src/core/icmpv6.go rename to src/ipv6rwc/icmpv6.go index d15fbbcb..8faf1d51 100644 --- a/src/core/icmpv6.go +++ b/src/ipv6rwc/icmpv6.go @@ -1,4 +1,4 @@ -package core +package ipv6rwc // The ICMPv6 module implements functions to easily create ICMPv6 // packets. These functions, when mixed with the built-in Go IPv6 diff --git a/src/core/keystore.go b/src/ipv6rwc/ipv6rwc.go similarity index 77% rename from src/core/keystore.go rename to src/ipv6rwc/ipv6rwc.go index 21fb8459..1c715f0f 100644 --- a/src/core/keystore.go +++ b/src/ipv6rwc/ipv6rwc.go @@ -1,4 +1,4 @@ -package core +package ipv6rwc import ( "crypto/ed25519" @@ -14,14 +14,22 @@ import ( iwt "github.com/Arceliar/ironwood/types" "github.com/yggdrasil-network/yggdrasil-go/src/address" + "github.com/yggdrasil-network/yggdrasil-go/src/core" ) const keyStoreTimeout = 2 * time.Minute +// Out-of-band packet types +const ( + typeKeyDummy = iota // nolint:deadcode,varcheck + typeKeyLookup + typeKeyResponse +) + type keyArray [ed25519.PublicKeySize]byte type keyStore struct { - core *Core + core *core.Core address address.Address subnet address.Subnet mutex sync.Mutex @@ -45,11 +53,11 @@ type buffer struct { timeout *time.Timer } -func (k *keyStore) init(core *Core) { - k.core = core - k.address = *address.AddrForKey(k.core.public) - k.subnet = *address.SubnetForKey(k.core.public) - if err := k.core.pc.SetOutOfBandHandler(k.oobHandler); err != nil { +func (k *keyStore) init(c *core.Core) { + k.core = c + k.address = *address.AddrForKey(k.core.PublicKey()) + k.subnet = *address.SubnetForKey(k.core.PublicKey()) + if err := k.core.SetOutOfBandHandler(k.oobHandler); err != nil { err = fmt.Errorf("tun.core.SetOutOfBandHander: %w", err) panic(err) } @@ -66,7 +74,7 @@ func (k *keyStore) sendToAddress(addr address.Address, bs []byte) { if info := k.addrToInfo[addr]; info != nil { k.resetTimeout(info) k.mutex.Unlock() - _, _ = k.core.pc.WriteTo(bs, iwt.Addr(info.key[:])) + _, _ = k.core.WriteTo(bs, iwt.Addr(info.key[:])) } else { var buf *buffer if buf = k.addrBuffer[addr]; buf == nil { @@ -95,7 +103,7 @@ func (k *keyStore) sendToSubnet(subnet address.Subnet, bs []byte) { if info := k.subnetToInfo[subnet]; info != nil { k.resetTimeout(info) k.mutex.Unlock() - _, _ = k.core.pc.WriteTo(bs, iwt.Addr(info.key[:])) + _, _ = k.core.WriteTo(bs, iwt.Addr(info.key[:])) } else { var buf *buffer if buf = k.subnetBuffer[subnet]; buf == nil { @@ -135,11 +143,11 @@ func (k *keyStore) update(key ed25519.PublicKey) *keyInfo { k.resetTimeout(info) k.mutex.Unlock() if buf := k.addrBuffer[info.address]; buf != nil { - k.core.pc.WriteTo(buf.packet, iwt.Addr(info.key[:])) + k.core.WriteTo(buf.packet, iwt.Addr(info.key[:])) delete(k.addrBuffer, info.address) } if buf := k.subnetBuffer[info.subnet]; buf != nil { - k.core.pc.WriteTo(buf.packet, iwt.Addr(info.key[:])) + k.core.WriteTo(buf.packet, iwt.Addr(info.key[:])) delete(k.subnetBuffer, info.subnet) } } else { @@ -191,46 +199,29 @@ func (k *keyStore) oobHandler(fromKey, toKey ed25519.PublicKey, data []byte) { } func (k *keyStore) sendKeyLookup(partial ed25519.PublicKey) { - sig := ed25519.Sign(k.core.secret, partial[:]) + sig := ed25519.Sign(k.core.PrivateKey(), partial[:]) bs := append([]byte{typeKeyLookup}, sig...) - _ = k.core.pc.SendOutOfBand(partial, bs) + _ = k.core.SendOutOfBand(partial, bs) } func (k *keyStore) sendKeyResponse(dest ed25519.PublicKey) { - sig := ed25519.Sign(k.core.secret, dest[:]) + sig := ed25519.Sign(k.core.PrivateKey(), dest[:]) bs := append([]byte{typeKeyResponse}, sig...) - _ = k.core.pc.SendOutOfBand(dest, bs) -} - -func (k *keyStore) maxSessionMTU() uint64 { - const sessionTypeOverhead = 1 - return k.core.pc.MTU() - sessionTypeOverhead + _ = k.core.SendOutOfBand(dest, bs) } func (k *keyStore) readPC(p []byte) (int, error) { - buf := make([]byte, k.core.pc.MTU(), 65535) + buf := make([]byte, k.core.MTU(), 65535) for { bs := buf - n, from, err := k.core.pc.ReadFrom(bs) + n, from, err := k.core.ReadFrom(bs) if err != nil { return n, err } if n == 0 { continue } - switch bs[0] { - case typeSessionTraffic: - // This is what we want to handle here - case typeSessionProto: - var key keyArray - copy(key[:], from.(iwt.Addr)) - data := append([]byte(nil), bs[1:n]...) - k.core.proto.handleProto(nil, key, data) - continue - default: - continue - } - bs = bs[1:n] + bs = bs[:n] if len(bs) == 0 { continue } @@ -294,15 +285,69 @@ func (k *keyStore) writePC(bs []byte) (int, error) { strErr := fmt.Sprint("incorrect source address: ", net.IP(srcAddr[:]).String()) return 0, errors.New(strErr) } - buf := make([]byte, 1+len(bs), 65535) - buf[0] = typeSessionTraffic - copy(buf[1:], bs) if dstAddr.IsValid() { - k.sendToAddress(dstAddr, buf) + k.sendToAddress(dstAddr, bs) } else if dstSubnet.IsValid() { - k.sendToSubnet(dstSubnet, buf) + k.sendToSubnet(dstSubnet, bs) } else { return 0, errors.New("invalid destination address") } return len(bs), nil } + +// Exported API + +func (k *keyStore) MaxMTU() uint64 { + return k.core.MTU() +} + +func (k *keyStore) SetMTU(mtu uint64) { + if mtu > k.MaxMTU() { + mtu = k.MaxMTU() + } + if mtu < 1280 { + mtu = 1280 + } + k.mutex.Lock() + k.mtu = mtu + k.mutex.Unlock() +} + +func (k *keyStore) MTU() uint64 { + k.mutex.Lock() + mtu := k.mtu + k.mutex.Unlock() + return mtu +} + +type ReadWriteCloser struct { + keyStore +} + +func NewReadWriteCloser(c *core.Core) *ReadWriteCloser { + rwc := new(ReadWriteCloser) + rwc.init(c) + return rwc +} + +func (rwc *ReadWriteCloser) Address() address.Address { + return rwc.address +} + +func (rwc *ReadWriteCloser) Subnet() address.Subnet { + return rwc.subnet +} + +func (rwc *ReadWriteCloser) Read(p []byte) (n int, err error) { + return rwc.readPC(p) +} + +func (rwc *ReadWriteCloser) Write(p []byte) (n int, err error) { + return rwc.writePC(p) +} + +func (rwc *ReadWriteCloser) Close() error { + err := rwc.core.Close() + rwc.core.Stop() + return err +} diff --git a/src/tuntap/iface.go b/src/tuntap/iface.go index e72b091f..f629399a 100644 --- a/src/tuntap/iface.go +++ b/src/tuntap/iface.go @@ -17,7 +17,7 @@ func (tun *TunAdapter) read() { begin := TUN_OFFSET_BYTES end := begin + n bs := buf[begin:end] - if _, err := tun.core.Write(bs); err != nil { + if _, err := tun.rwc.Write(bs); err != nil { tun.log.Debugln("Unable to send packet:", err) } } @@ -27,7 +27,7 @@ func (tun *TunAdapter) write() { var buf [TUN_OFFSET_BYTES + 65535]byte for { bs := buf[TUN_OFFSET_BYTES:] - n, err := tun.core.Read(bs) + n, err := tun.rwc.Read(bs) if err != nil { tun.log.Errorln("Exiting tun writer due to core read error:", err) return diff --git a/src/tuntap/tun.go b/src/tuntap/tun.go index 8cbe537b..fb483a0f 100644 --- a/src/tuntap/tun.go +++ b/src/tuntap/tun.go @@ -21,8 +21,8 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/config" - "github.com/yggdrasil-network/yggdrasil-go/src/core" "github.com/yggdrasil-network/yggdrasil-go/src/defaults" + "github.com/yggdrasil-network/yggdrasil-go/src/ipv6rwc" ) type MTU uint16 @@ -32,7 +32,7 @@ type MTU uint16 // should pass this object to the yggdrasil.SetRouterAdapter() function before // calling yggdrasil.Start(). type TunAdapter struct { - core *core.Core + rwc *ipv6rwc.ReadWriteCloser config *config.NodeConfig log *log.Logger addr address.Address @@ -94,8 +94,8 @@ func MaximumMTU() uint64 { // Init initialises the TUN module. You must have acquired a Listener from // the Yggdrasil core before this point and it must not be in use elsewhere. -func (tun *TunAdapter) Init(core *core.Core, config *config.NodeConfig, log *log.Logger, options interface{}) error { - tun.core = core +func (tun *TunAdapter) Init(rwc *ipv6rwc.ReadWriteCloser, config *config.NodeConfig, log *log.Logger, options interface{}) error { + tun.rwc = rwc tun.config = config tun.log = log return nil @@ -120,9 +120,8 @@ func (tun *TunAdapter) _start() error { if tun.config == nil { return errors.New("no configuration available to TUN") } - pk := tun.core.PublicKey() - tun.addr = *address.AddrForKey(pk) - tun.subnet = *address.SubnetForKey(pk) + tun.addr = tun.rwc.Address() + tun.subnet = tun.rwc.Subnet() addr := fmt.Sprintf("%s/%d", net.IP(tun.addr[:]).String(), 8*len(address.GetPrefix())-1) if tun.config.IfName == "none" || tun.config.IfName == "dummy" { tun.log.Debugln("Not starting TUN as ifname is none or dummy") @@ -131,8 +130,8 @@ func (tun *TunAdapter) _start() error { return nil } mtu := tun.config.IfMTU - if tun.core.MaxMTU() < mtu { - mtu = tun.core.MaxMTU() + if tun.rwc.MaxMTU() < mtu { + mtu = tun.rwc.MaxMTU() } if err := tun.setup(tun.config.IfName, addr, mtu); err != nil { return err @@ -140,7 +139,7 @@ func (tun *TunAdapter) _start() error { if tun.MTU() != mtu { tun.log.Warnf("Warning: Interface MTU %d automatically adjusted to %d (supported range is 1280-%d)", tun.config.IfMTU, tun.MTU(), MaximumMTU()) } - tun.core.SetMTU(tun.MTU()) + tun.rwc.SetMTU(tun.MTU()) tun.isOpen = true tun.ckr.init(tun) tun.isEnabled = true