diff --git a/gemini/gemini.go b/gemini/gemini.go index d8348c8..d44b176 100644 --- a/gemini/gemini.go +++ b/gemini/gemini.go @@ -49,6 +49,7 @@ var ( ErrServerClosed = errors.New("gemini: server closed") ErrHeaderTooLong = errors.New("gemini: header too long") ErrMissingFile = errors.New("gemini: no such file") + ErrEmptyRequest = errors.New("gemini: empty request") ) type Request struct { @@ -90,6 +91,11 @@ type Server struct { Handler Handler // handler to invoke ReadTimeout time.Duration MaxOpenConns int + + // internal + listener net.Listener + shutdown bool + closed chan struct{} } func (s *Server) log(v string) { @@ -107,10 +113,13 @@ func (s *Server) logf(format string, v ...interface{}) { } func (s *Server) ListenAndServe() error { + s.closed = make(chan struct{}) + // outer for loop, if listener closes we will restart it. This may be useful if we switch out // TLSConfig. //for { - listener, err := tls.Listen("tcp", s.Addr, s.TLSConfig) + var err error + s.listener, err = tls.Listen("tcp", s.Addr, s.TLSConfig) if err != nil { return fmt.Errorf("gemini server listen: %w", err) } @@ -119,15 +128,26 @@ func (s *Server) ListenAndServe() error { go s.handleConnectionQueue(queue) for { - conn, err := listener.Accept() + conn, err := s.listener.Accept() if err != nil { s.logf("server accept error: %v", err) break } queue <- conn + + // un-stuck call after shutdown will trigger a drop here + if s.shutdown { + break + } } + // closed confirms the accept call stopped + close(s.closed) + //if s.shutdown { + // return nil //} - return nil + //} + s.log("closing listener gracefully") + return s.listener.Close() } func (s *Server) handleConnectionQueue(queue chan net.Conn) { @@ -171,6 +191,11 @@ func (s *Server) handleConnection(conn net.Conn, sem chan struct{}) { } func (s *Server) handleRequestError(conn net.Conn, req request) { + if errors.Is(req.err, ErrEmptyRequest) { + // silently ignore empty requests. + return + } + s.logf("server error: '%s' %v", strings.TrimSpace(req.rawuri), req.err) var gmierr *GmiError @@ -192,17 +217,22 @@ type request struct { } func requestChannel(c net.Conn, rsp chan request) { + req := &request{} r, err := readHeader(c) - r.err = err - rsp <- *r + if r != nil { + req = r + } + req.err = err + rsp <- *req } func readHeader(c net.Conn) (*request, error) { - r := &request{} req, err := bufio.NewReader(c).ReadString('\r') if err != nil { - return nil, Error(StatusTemporaryFailure, errors.New("error reading request")) + return nil, Error(StatusTemporaryFailure, ErrEmptyRequest) } + + r := &request{} r.rawuri = req requestURL := strings.TrimSpace(req) @@ -237,7 +267,30 @@ func readHeader(c net.Conn) (*request, error) { } func (s *Server) Shutdown(ctx context.Context) error { + s.log("shutdown request received") + t := time.Now() + go func() { + s.shutdown = true + // un-stuck call to self + conn, err := tls.Dial("tcp", "localhost:1965", &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + s.logf("un-stuck call error: %v", err) + return + } + defer conn.Close() + }() + select { + case <-s.closed: + s.log("all clients exited") + case <-ctx.Done(): + s.logf("shutdown: context deadline exceeded after %v, terminating listener", time.Since(t)) + if err := s.listener.Close(); err != nil { + s.logf("error while closing listener %v", err) + } + } return nil } diff --git a/main.go b/main.go index 476d0c0..96a9650 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "errors" "flag" @@ -10,6 +11,7 @@ import ( "log" "mime" "os" + "os/signal" "path" "path/filepath" "strings" @@ -47,7 +49,7 @@ func main() { flag.Parse() // TODO: rotate on SIGHUP - flogger := log.New(os.Stdout, "", log.LUTC|log.Ldate|log.Ltime) + mlogger := log.New(os.Stdout, "", log.LUTC|log.Ldate|log.Ltime) if logs != "" { logpath := filepath.Join(logs, "access.log") accessLog, err := os.OpenFile(logpath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) @@ -56,7 +58,7 @@ func main() { } defer accessLog.Close() - flogger.SetOutput(accessLog) + mlogger.SetOutput(accessLog) } var dlogger *log.Logger @@ -98,7 +100,7 @@ func main() { } mux := gemini.NewMux() - mux.Use(logger(flogger)) + mux.Use(logger(mlogger)) mux.Handle(gemini.HandlerFunc(fileserver(root))) server := &gemini.Server{ @@ -111,29 +113,27 @@ func main() { Logger: dlogger, } - //confirm := make(chan struct{}, 1) - //go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, gemini.ErrServerClosed) { - log.Fatal("ListenAndServe terminated unexpectedly") + confirm := make(chan struct{}, 1) + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, gemini.ErrServerClosed) { + log.Fatalf("ListenAndServe terminated unexpectedly: %v", err) + } + close(confirm) + }() + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt) + <-stop + + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + if err := server.Shutdown(ctx); err != nil { + cancel() + log.Fatal("ListenAndServe shutdown") } - // close(confirm) - //}() + <-confirm + cancel() - /* - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt) - <-stop - - ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) - if err := server.Shutdown(ctx); err != nil { - cancel() - log.Fatal("ListenAndServe shutdown") - } - - <-confirm - cancel() - */ /* hup := make(chan os.Signal, 1) signal.Notify(hup, syscall.SIGHUP) @@ -154,7 +154,7 @@ func logger(log *log.Logger) func(next gemini.Handler) gemini.Handler { ip := strings.Split(r.RemoteAddr, ":")[0] hostname, _ := os.Hostname() - fmt.Printf("%s %s - - [%s] \"%s\" - %v\n", + fmt.Fprintf(log.Writer(), "%s %s - - [%s] \"%s\" - %v\n", hostname, ip, t.Format("02/Jan/2006:15:04:05 -0700"),