Fix unix domain socket creation/cleanup logic

* If there is another instance still listening on the same
    Unix domain socket, bail out

  * If there is a leftover from crashed yggstack etc,
    clean the socket file and proceed

Signed-off-by: Vasyl Gello <vasek.gello@gmail.com>
This commit is contained in:
Vasyl Gello 2024-07-15 13:15:22 +03:00
parent 23d4321be4
commit ef1c547a3f
No known key found for this signature in database
GPG key ID: 8A52BC6C291FB280

View file

@ -12,6 +12,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"regexp" "regexp"
"runtime"
"strings" "strings"
"syscall" "syscall"
@ -273,44 +274,48 @@ func main() {
// Create SOCKS server // Create SOCKS server
{ {
if socks != nil { if socks != nil && *socks != "" {
if nameserver != nil { socksOptions := []socks5.Option{
if strings.Contains(*socks, ":") { socks5.WithDial(s.DialContext),
logger.Infof("Starting SOCKS server on %s", *socks) }
resolver := types.NewNameResolver(s, *nameserver) if nameserver != nil && *nameserver != "" {
socksOptions := []socks5.Option{ resolver := types.NewNameResolver(s, *nameserver)
socks5.WithDial(s.DialContext), socksOptions = append(socksOptions, socks5.WithResolver(resolver))
socks5.WithResolver(resolver), } else {
} logger.Warningf("DNS nameserver is not set!")
if logger.GetLevel("debug") { logger.Warningf("SOCKS server will not be able to resolve hostnames other than .pk.ygg !")
socksOptions = append(socksOptions, socks5.WithLogger(logger)) }
} if logger.GetLevel("debug") {
server := socks5.NewServer(socksOptions...) socksOptions = append(socksOptions, socks5.WithLogger(logger))
go server.ListenAndServe("tcp", *socks) // nolint:errcheck }
} else { server := socks5.NewServer(socksOptions...)
logger.Infof("Starting SOCKS server with socket file %s", *socks) if strings.Contains(*socks, ":") {
_, err := os.Stat(*socks) logger.Infof("Starting SOCKS server on %s", *socks)
if os.IsNotExist(err) { go server.ListenAndServe("tcp", *socks) // nolint:errcheck
n.socks5Listener, err = net.Listen("unix", *socks) } else {
logger.Infof("Starting SOCKS server with socket file %s", *socks)
n.socks5Listener, err = net.Listen("unix", *socks)
if err != nil {
// If address in use, try connecting to
// the socket to see if other yggstack
// instance is listening on it
if isErrorAddressAlreadyInUse(err) {
_, err = net.Dial("unix", *socks)
if err != nil { if err != nil {
panic(err) // Unlink dead socket if not connected
err = os.RemoveAll(*socks)
if err != nil {
panic(err)
}
} else {
panic(fmt.Errorf("Another yggstack instance is listening on socket '%s'", *socks))
} }
resolver := types.NewNameResolver(s, *nameserver)
socksOptions := []socks5.Option{
socks5.WithDial(s.DialContext),
socks5.WithResolver(resolver),
}
if logger.GetLevel("debug") {
socksOptions = append(socksOptions, socks5.WithLogger(logger))
}
server := socks5.NewServer(socksOptions...)
go server.Serve(n.socks5Listener) // nolint:errcheck
} else if err != nil {
logger.Errorf("Cannot create socket file %s: %s", *socks, err)
} else { } else {
panic(errors.New(fmt.Sprintf("Socket file %s already exists", *socks))) panic(err)
} }
} }
go server.Serve(n.socks5Listener) // nolint:errcheck
} }
} }
} }
@ -349,11 +354,34 @@ func main() {
_ = n.multicast.Stop() _ = n.multicast.Stop()
if n.socks5Listener != nil { if n.socks5Listener != nil {
_ = n.socks5Listener.Close() _ = n.socks5Listener.Close()
_ = os.RemoveAll(*socks)
logger.Infof("Stopped UNIX socket listener") logger.Infof("Stopped UNIX socket listener")
} }
n.core.Stop() n.core.Stop()
} }
// Helper to detect if socket address is in use
// https://stackoverflow.com/a/52152912
func isErrorAddressAlreadyInUse(err error) bool {
var eOsSyscall *os.SyscallError
if !errors.As(err, &eOsSyscall) {
return false
}
var errErrno syscall.Errno // doesn't need a "*" (ptr) because it's already a ptr (uintptr)
if !errors.As(eOsSyscall, &errErrno) {
return false
}
if errors.Is(errErrno, syscall.EADDRINUSE) {
return true
}
const WSAEADDRINUSE = 10048
if runtime.GOOS == "windows" && errErrno == WSAEADDRINUSE {
return true
}
return false
}
// Helper to set logging level
func setLogLevel(loglevel string, logger *log.Logger) { func setLogLevel(loglevel string, logger *log.Logger) {
levels := [...]string{"error", "warn", "info", "debug", "trace"} levels := [...]string{"error", "warn", "info", "debug", "trace"}
loglevel = strings.ToLower(loglevel) loglevel = strings.ToLower(loglevel)