diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 82b85cd4..6e7e312b 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "crypto/ed25519" "encoding/hex" "encoding/json" @@ -12,6 +11,7 @@ import ( "os/signal" "regexp" "strings" + "sync" "syscall" "github.com/gologme/log" @@ -37,6 +37,11 @@ type node struct { admin *admin.AdminSocket } +var ( + sigCh = make(chan os.Signal, 1) + doneCh = make(chan struct{}) +) + // The main function is responsible for configuring and starting Yggdrasil. func main() { genconf := flag.Bool("genconf", false, "print a new config to stdout") @@ -53,12 +58,8 @@ func main() { getpkey := flag.Bool("publickey", false, "use in combination with either -useconf or -useconffile, outputs your public key") loglevel := flag.String("loglevel", "info", "loglevel to enable") flag.Parse() - - // Catch interrupts from the operating system to exit gracefully. - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - - // Capture the service being stopped on Windows. - minwinsvc.SetOnExit(cancel) + + signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) // Create a new logger that logs output to stdout. var logger *log.Logger @@ -271,15 +272,53 @@ func main() { n.tun.SetupAdminHandlers(n.admin) } } + + //Windows service shutdown service + minwinsvc.SetOnExit(func() { + logger.Infof("Shutting down service ...") + sigCh <- os.Interrupt + // Wait for all parts to shutdown properly + <-doneCh + }) // Block until we are told to shut down. - <-ctx.Done() + <-sigCh - // Shut down the node. - _ = n.admin.Stop() - _ = n.multicast.Stop() - _ = n.tun.Stop() + // Shut down the node using a wait group to synchronize + var wg sync.WaitGroup + + wg.Add(3) + + go func() { + defer wg.Done() + if err := n.admin.Stop(); err != nil { + logger.Errorf("Error stopping admin: %v", err) + } + }() + + go func() { + defer wg.Done() + if err := n.multicast.Stop(); err != nil { + logger.Errorf("Error stopping multicast: %v", err) + } + }() + + go func() { + defer wg.Done() + if err := n.tun.Stop(); err != nil { + logger.Errorf("Error stopping tun: %v", err) + } + }() + + // Stop the core synchronously since it's not in a goroutine n.core.Stop() + + // Wait for all goroutines to finish + wg.Wait() + + // Notify that shutdown is complete + close(doneCh) + close(sigCh) } func setLogLevel(loglevel string, logger *log.Logger) {