commit 2916082d5ed94ef86ad58bdb7256ae07b214c4f3 Author: Alexander NeonXP Kiryukhin Date: Mon Jul 29 02:38:17 2024 +0300 Начальный коммит diff --git a/binder.go b/binder.go new file mode 100644 index 0000000..f618a1a --- /dev/null +++ b/binder.go @@ -0,0 +1,88 @@ +package mux + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "reflect" + "strconv" + "strings" +) + +func Bind[T any](r *http.Request, obj *T) error { + contentType := r.Header.Get("Content-Type") + switch { + case strings.HasPrefix(contentType, "multipart/form-data"), + strings.HasPrefix(contentType, "application/x-www-form-urlencoded"): + if err := r.ParseForm(); err != nil { + return err + } + return bindForm(r.Form, obj) + case strings.HasPrefix(contentType, "application/json"): + defer r.Body.Close() + return json.NewDecoder(r.Body).Decode(obj) + case r.Method == http.MethodGet: + return bindForm(r.URL.Query(), obj) + case r.Method == http.MethodPost: + return fmt.Errorf("invalid content-type: %s", contentType) + } + + return nil +} + +func bindForm(values url.Values, obj any) error { + val := reflect.ValueOf(obj) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + fields := val.NumField() + + for i := 0; i < fields; i++ { + f := val.Field(i) + if !f.IsValid() { + continue + } + if !f.CanSet() { + continue + } + t := val.Type().Field(i) + k := t.Tag.Get("form") + if k == "" { + continue + } + if !values.Has(k) { + continue + } + v := values.Get(k) + + switch f.Type().Kind() { + case reflect.Bool: + switch v { + case "on", "true", "1": + f.SetBool(true) + default: + f.SetBool(false) + } + case reflect.Int, reflect.Int64: + if i, e := strconv.ParseInt(v, 0, 0); e == nil { + f.SetInt(i) + } else { + return fmt.Errorf("could not set int value of %s: %s", k, e) + } + case reflect.Float64: + if fl, e := strconv.ParseFloat(v, 64); e == nil { + f.SetFloat(fl) + } else { + return fmt.Errorf("could not set float64 value of %s: %s", k, e) + } + case reflect.String: + f.SetString(v) + default: + return fmt.Errorf("unsupported format %v for field %s", f.Type().Kind(), k) + } + } + + return nil +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..416d578 --- /dev/null +++ b/error.go @@ -0,0 +1,19 @@ +package mux + +import ( + "context" + "encoding/json" + "io" +) + +var DefaultErrorHandler func(err error) Renderer = func(err error) Renderer { + return RendererFunc(func(ctx context.Context, w io.Writer) error { + return json.NewEncoder(w).Encode(errorStruct{ + Message: err.Error(), + }) + }) +} + +type errorStruct struct { + Message string `json:"message"` +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2c9b4ee --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module go.neonxp.ru/mux + +go 1.22.5 + +require go.neonxp.ru/objectid v0.0.2 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3ce48c1 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +go.neonxp.ru/objectid v0.0.2 h1:Z/G6zvBxmUq0NTq681oGH8pTbBWwi6VA22YOYludIPs= +go.neonxp.ru/objectid v0.0.2/go.mod h1:s0dRi//oe1liiKcor1KmWx09WzkD6Wtww8ZaIv+VLBs= diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..2b56105 --- /dev/null +++ b/middleware.go @@ -0,0 +1,13 @@ +package mux + +import "net/http" + +type Middleware func(http.Handler) http.Handler + +func Use(handler http.Handler, middlewares ...Middleware) http.Handler { + for _, h := range middlewares { + handler = h(handler) + } + + return handler +} diff --git a/middleware/context.go b/middleware/context.go new file mode 100644 index 0000000..b9ad45f --- /dev/null +++ b/middleware/context.go @@ -0,0 +1,11 @@ +package middleware + +type ctxKey int + +const ( + requestIDKey ctxKey = iota + sessionIDKey + sessionValueKey + sessionConfigKey + sessionStorerKey +) diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 0000000..80117da --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "log/slog" + "net/http" + "time" + + "go.neonxp.ru/mux" +) + +type wrappedResponse struct { + http.ResponseWriter + statusCode int +} + +func (w *wrappedResponse) WriteHeader(code int) { + w.statusCode = code + w.ResponseWriter.WriteHeader(code) +} + +func Logger(logger *slog.Logger) mux.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := GetRequestID(r) + args := []any{ + slog.String("proto", r.Proto), + slog.String("method", r.Method), + slog.String("request_uri", r.RequestURI), + slog.String("request_id", requestID), + } + logger.InfoContext( + r.Context(), + "start request", + args..., + ) + t := time.Now() + wr := &wrappedResponse{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(wr, r) + args = append(args, slog.String("response_time", time.Since(t).String())) + args = append(args, slog.Int("response_status", wr.statusCode)) + logger.InfoContext( + r.Context(), + "finish request", + args..., + ) + }) + } +} diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 0000000..b34d582 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "log/slog" + "net/http" + + "go.neonxp.ru/mux" +) + +func Recover(logger *slog.Logger) mux.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + err := recover() + if err == nil { + return + } + requestID := GetRequestID(r) + logger.ErrorContext( + r.Context(), + "panic", + slog.Any("panic", err), + slog.String("proto", r.Proto), + slog.String("method", r.Method), + slog.String("request_uri", r.RequestURI), + slog.String("request_id", requestID), + ) + }() + + next.ServeHTTP(w, r) + }) + } +} diff --git a/middleware/request_id.go b/middleware/request_id.go new file mode 100644 index 0000000..016b44a --- /dev/null +++ b/middleware/request_id.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "context" + "net/http" + + "go.neonxp.ru/objectid" +) + +const RequestIDHeader string = "X-Request-ID" + +func RequestID(next http.Handler) http.Handler { + objectid.Seed() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get(RequestIDHeader) + if requestID == "" { + requestID = objectid.New().String() + } + + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), requestIDKey, requestID))) + }) +} + +func GetRequestID(r *http.Request) string { + rid := r.Context().Value(requestIDKey) + if rid == nil { + return "" + } + srid, ok := rid.(string) + if !ok { + return "" + } + + return srid +} diff --git a/middleware/session.go b/middleware/session.go new file mode 100644 index 0000000..838e088 --- /dev/null +++ b/middleware/session.go @@ -0,0 +1,89 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + + "go.neonxp.ru/mux" + "go.neonxp.ru/mux/middleware/session" + "go.neonxp.ru/objectid" +) + +type SessionConfig struct { + SessionCookie string + Path string + Domain string + Secure bool + HttpOnly bool + MaxAge int +} + +var DefaultSessionConfig SessionConfig = SessionConfig{ + SessionCookie: "_session", + Path: "/", + Domain: "", + Secure: false, + HttpOnly: true, + MaxAge: 30 * 3600, +} + +func Session(config SessionConfig, storer session.Store) mux.Middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ( + sessionID string + values session.Value + ) + cookie, err := r.Cookie(config.SessionCookie) + switch { + case err == nil: + sessionID = cookie.Value + values = storer.Load(r.Context(), sessionID) + case errors.Is(err, http.ErrNoCookie): + sessionID = objectid.New().String() + values = session.Value{} + } + + http.SetCookie(w, &http.Cookie{ + Name: config.SessionCookie, + Value: sessionID, + Path: config.Path, + Domain: config.Domain, + Secure: config.Secure, + HttpOnly: config.HttpOnly, + MaxAge: config.MaxAge, + }) + + ctx := context.WithValue(r.Context(), sessionValueKey, &values) + ctx = context.WithValue(ctx, sessionIDKey, sessionID) + ctx = context.WithValue(ctx, sessionConfigKey, config) + ctx = context.WithValue(ctx, sessionStorerKey, storer) + + h.ServeHTTP(w, r.WithContext(ctx)) + + storer.Save(r.Context(), sessionID, values) + + }) + } +} + +func SessionFromRequest(r *http.Request) *session.Value { + return r.Context().Value(sessionValueKey).(*session.Value) +} + +func ClearSession(w http.ResponseWriter, r *http.Request) { + storer := r.Context().Value(sessionStorerKey).(session.Store) + sessionID := r.Context().Value(sessionIDKey).(string) + storer.Remove(r.Context(), sessionID) + config := r.Context().Value(sessionConfigKey).(SessionConfig) + http.SetCookie(w, &http.Cookie{ + Name: config.SessionCookie, + Value: sessionID, + Path: config.Path, + Domain: config.Domain, + Secure: config.Secure, + HttpOnly: config.HttpOnly, + MaxAge: -1, + }) +} diff --git a/middleware/session/bbolt.go b/middleware/session/bbolt.go new file mode 100644 index 0000000..1068ed8 --- /dev/null +++ b/middleware/session/bbolt.go @@ -0,0 +1,71 @@ +package session + +import ( + "bytes" + "context" + "encoding/gob" + "log/slog" + + "go.etcd.io/bbolt" +) + +func New(db *bbolt.DB, bucketName []byte) Store { + return &BoltStore{ + db: db, + bucketName: bucketName, + } +} + +type BoltStore struct { + db *bbolt.DB + bucketName []byte +} + +func (s *BoltStore) Load(ctx context.Context, sessionID string) Value { + v := Value{} + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(s.bucketName) + if bucket == nil { + // no bucket -- normal situation + return nil + } + vb := bucket.Get([]byte(sessionID)) + if vb == nil { + // no session -- no error + return nil + } + rdr := bytes.NewBuffer(vb) + + return gob.NewDecoder(rdr).Decode(&v) + }) + if err != nil { + slog.WarnContext(ctx, "failed load session", slog.Any("error", err)) + } + return v +} + +func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error { + return s.db.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(s.bucketName) + if err != nil { + return err + } + wrt := bytes.NewBuffer([]byte{}) + if err := gob.NewEncoder(wrt).Encode(value); err != nil { + return err + } + + return bucket.Put([]byte(sessionID), wrt.Bytes()) + }) +} + +func (s *BoltStore) Remove(ctx context.Context, sessionID string) error { + return s.db.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(s.bucketName) + if err != nil { + return err + } + + return bucket.Delete([]byte(sessionID)) + }) +} diff --git a/middleware/session/memstore.go b/middleware/session/memstore.go new file mode 100644 index 0000000..2fcef39 --- /dev/null +++ b/middleware/session/memstore.go @@ -0,0 +1,31 @@ +package session + +import ( + "context" + "sync" +) + +type MemoryStore struct { + store sync.Map +} + +func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value { + val, ok := s.store.Load(sessionID) + if ok { + return val.(Value) + } + + return Value{} +} + +func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error { + s.store.Store(sessionID, value) + + return nil +} + +func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error { + s.store.Delete(sessionID) + + return nil +} diff --git a/middleware/session/store.go b/middleware/session/store.go new file mode 100644 index 0000000..b74a8aa --- /dev/null +++ b/middleware/session/store.go @@ -0,0 +1,18 @@ +package session + +import ( + "context" + "errors" +) + +var ( + ErrSessionNotFound = errors.New("session not found") +) + +type Store interface { + Load(ctx context.Context, sessionID string) Value + Save(ctx context.Context, sessionID string, value Value) error + Remove(ctx context.Context, sessionID string) error +} + +type Value map[string]any diff --git a/redirect.go b/redirect.go new file mode 100644 index 0000000..a392234 --- /dev/null +++ b/redirect.go @@ -0,0 +1,8 @@ +package mux + +import "net/http" + +func Redirect(w http.ResponseWriter, code int, location string) { + w.Header().Add("Location", location) + w.WriteHeader(code) +} diff --git a/render.go b/render.go new file mode 100644 index 0000000..8b39090 --- /dev/null +++ b/render.go @@ -0,0 +1,24 @@ +package mux + +import ( + "context" + "io" + "log/slog" + "net/http" +) + +type Renderer interface { + Render(context.Context, io.Writer) error +} + +func Render(w http.ResponseWriter, r *http.Request, renderable Renderer) { + if err := renderable.Render(r.Context(), w); err != nil { + slog.ErrorContext(r.Context(), "failed render template", slog.Any("err", err)) + } +} + +type RendererFunc func(context.Context, io.Writer) error + +func (r RendererFunc) Render(ctx context.Context, w io.Writer) error { + return r(ctx, w) +}