diff --git a/cmd/yggdrasilctl/main.go b/cmd/yggdrasilctl/main.go index 815e6095..e3bbf24c 100644 --- a/cmd/yggdrasilctl/main.go +++ b/cmd/yggdrasilctl/main.go @@ -66,30 +66,7 @@ func run() int { cmdLineEnv.setEndpoint(logger) - var conn net.Conn - u, err := url.Parse(cmdLineEnv.endpoint) - - if err == nil { - switch strings.ToLower(u.Scheme) { - case "unix": - logger.Println("Connecting to UNIX socket", cmdLineEnv.endpoint[7:]) - conn, err = net.Dial("unix", cmdLineEnv.endpoint[7:]) - case "tcp": - logger.Println("Connecting to TCP socket", u.Host) - conn, err = net.Dial("tcp", u.Host) - default: - logger.Println("Unknown protocol or malformed address - check your endpoint") - err = errors.New("protocol not supported") - } - } else { - logger.Println("Connecting to TCP socket", u.Host) - conn, err = net.Dial("tcp", cmdLineEnv.endpoint) - } - - if err != nil { - panic(err) - } - + conn := connect(cmdLineEnv.endpoint, logger) logger.Println("Connected") defer conn.Close() @@ -249,6 +226,35 @@ func (cmdLineEnv *CmdLineEnv)setEndpoint(logger *log.Logger) { } } +func connect(endpoint string, logger *log.Logger) net.Conn { + var conn net.Conn + + u, err := url.Parse(endpoint) + + if err == nil { + switch strings.ToLower(u.Scheme) { + case "unix": + logger.Println("Connecting to UNIX socket", endpoint[7:]) + conn, err = net.Dial("unix", endpoint[7:]) + case "tcp": + logger.Println("Connecting to TCP socket", u.Host) + conn, err = net.Dial("tcp", u.Host) + default: + logger.Println("Unknown protocol or malformed address - check your endpoint") + err = errors.New("protocol not supported") + } + } else { + logger.Println("Connecting to TCP socket", u.Host) + conn, err = net.Dial("tcp", endpoint) + } + + if err != nil { + panic(err) + } + + return conn +} + func runAll(recv map[string]interface{}, verbose bool) { req := recv["request"].(map[string]interface{}) res := recv["response"].(map[string]interface{})