diff --git a/logger.go b/logger.go index 6a8ce7a..039bd19 100644 --- a/logger.go +++ b/logger.go @@ -6,17 +6,24 @@ import ( "log/slog" ) -func Logger(handler http.Handler, logger *slog.Logger) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handler.ServeHTTP(w, r) - requestID := GetRequestID(r) - logger.InfoContext( - r.Context(), - "request", - slog.String("proto", r.Proto), - slog.String("method", r.Method), - slog.String("request_uri", r.RequestURI), - slog.String("request_id", requestID), - ) - }) +func Logger(logger *slog.Logger) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + requestID := GetRequestID(r) + args := []any{ + slog.String("proto", r.Proto), + slog.String("method", r.Method), + slog.String("request_uri", r.RequestURI), + } + if requestID != "" { + args = append(args, slog.String("request_id", requestID)) + } + logger.InfoContext( + r.Context(), + "request", + args..., + ) + }) + } } diff --git a/recover.go b/recover.go index cbe12ac..6b5f2cb 100644 --- a/recover.go +++ b/recover.go @@ -7,26 +7,28 @@ import ( "log/slog" ) -func Recover(handler http.Handler, logger *slog.Logger) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - err := recover() - if err == nil { - return - } - debug.PrintStack() - requestID := GetRequestID(r) - logger.ErrorContext( - r.Context(), - "panic", - slog.Any("panic", err), - slog.String("proto", r.Proto), - slog.String("method", r.Method), - slog.String("request_uri", r.RequestURI), - slog.String("request_id", requestID), - ) - }() +func Recover(logger *slog.Logger) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + err := recover() + if err == nil { + return + } + debug.PrintStack() + requestID := GetRequestID(r) + logger.ErrorContext( + r.Context(), + "panic", + slog.Any("panic", err), + slog.String("proto", r.Proto), + slog.String("method", r.Method), + slog.String("request_uri", r.RequestURI), + slog.String("request_id", requestID), + ) + }() - handler.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) + }) + } } diff --git a/request_id.go b/request_id.go index 67a0c82..b3650ef 100644 --- a/request_id.go +++ b/request_id.go @@ -14,14 +14,14 @@ const ( RequestIDHeader string = "X-Request-ID" ) -func RequestID(handler http.Handler) http.Handler { +func RequestID(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestID := r.Header.Get(RequestIDHeader) if requestID == "" { requestID = uuid.NewString() } - handler.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), RequestIDKey, requestID))) + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), RequestIDKey, requestID))) }) } diff --git a/use.go b/use.go new file mode 100644 index 0000000..6610e2f --- /dev/null +++ b/use.go @@ -0,0 +1,13 @@ +package middleware + +import "net/http" + +type Middleware func(http.Handler) http.Handler + +func Use(handler http.Handler, middlewares ...Middleware) http.Handler { + for _, h := range middlewares { + handler = h(handler) + } + + return handler +}