From 18a9096684ab0b1468a08d442d38b7b675d94810 Mon Sep 17 00:00:00 2001 From: Alexander NeonXP Kiryukhin Date: Sun, 28 Jul 2024 19:32:33 +0300 Subject: [PATCH] Session middleware --- context.go | 8 ++++++ logger.go | 4 +-- recover.go | 2 -- request_id.go | 11 ++----- session.go | 65 ++++++++++++++++++++++++++++++++++++++++++ session/bbolt/bbolt.go | 61 +++++++++++++++++++++++++++++++++++++++ session/bbolt/go.mod | 7 +++++ session/bbolt/go.sum | 4 +++ session/memstore.go | 25 ++++++++++++++++ session/store.go | 17 +++++++++++ 10 files changed, 191 insertions(+), 13 deletions(-) create mode 100644 context.go create mode 100644 session.go create mode 100644 session/bbolt/bbolt.go create mode 100644 session/bbolt/go.mod create mode 100644 session/bbolt/go.sum create mode 100644 session/memstore.go create mode 100644 session/store.go diff --git a/context.go b/context.go new file mode 100644 index 0000000..b844226 --- /dev/null +++ b/context.go @@ -0,0 +1,8 @@ +package middleware + +type ctxKey struct{} + +var ( + requestIDKey ctxKey + sessionKey ctxKey +) diff --git a/logger.go b/logger.go index 039bd19..065a6a7 100644 --- a/logger.go +++ b/logger.go @@ -15,9 +15,7 @@ func Logger(logger *slog.Logger) Middleware { slog.String("proto", r.Proto), slog.String("method", r.Method), slog.String("request_uri", r.RequestURI), - } - if requestID != "" { - args = append(args, slog.String("request_id", requestID)) + slog.String("request_id", requestID), } logger.InfoContext( r.Context(), diff --git a/recover.go b/recover.go index 6b5f2cb..a962b90 100644 --- a/recover.go +++ b/recover.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "runtime/debug" "log/slog" ) @@ -15,7 +14,6 @@ func Recover(logger *slog.Logger) Middleware { if err == nil { return } - debug.PrintStack() requestID := GetRequestID(r) logger.ErrorContext( r.Context(), diff --git a/request_id.go b/request_id.go index 0e9a521..016b44a 100644 --- a/request_id.go +++ b/request_id.go @@ -7,12 +7,7 @@ import ( "go.neonxp.ru/objectid" ) -type ctxKeyRequestID int - -const ( - RequestIDKey ctxKeyRequestID = 0 - RequestIDHeader string = "X-Request-ID" -) +const RequestIDHeader string = "X-Request-ID" func RequestID(next http.Handler) http.Handler { objectid.Seed() @@ -22,12 +17,12 @@ func RequestID(next http.Handler) http.Handler { requestID = objectid.New().String() } - next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), RequestIDKey, requestID))) + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), requestIDKey, requestID))) }) } func GetRequestID(r *http.Request) string { - rid := r.Context().Value(RequestIDKey) + rid := r.Context().Value(requestIDKey) if rid == nil { return "" } diff --git a/session.go b/session.go new file mode 100644 index 0000000..11944d4 --- /dev/null +++ b/session.go @@ -0,0 +1,65 @@ +package middleware + +import ( + "context" + "net/http" + + "go.neonxp.ru/middleware/session" + "go.neonxp.ru/objectid" +) + +type SessionConfig struct { + SessionCookie string + Path string + Domain string + Secure bool + HttpOnly bool + MaxAge int +} + +type SessionManager struct { + SessionID string + Storer session.Store + MaxAge int +} + +func (s *SessionManager) Load(ctx context.Context) session.Value { + return s.Storer.Load(ctx, s.SessionID) +} + +func (s *SessionManager) Save(ctx context.Context, value session.Value) error { + return s.Storer.Save(ctx, s.SessionID, value) +} + +func (s *SessionManager) SetMaxAge(maxAge int) { + s.MaxAge = maxAge +} + +func Session(config *SessionConfig, storer session.Store) Middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID := objectid.New().String() + cookie, err := r.Cookie(config.SessionCookie) + if err == nil { + sessionID = cookie.Value + } + sessionManager := &SessionManager{SessionID: sessionID, Storer: storer, MaxAge: config.MaxAge} + + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), sessionKey, &sessionManager))) + + http.SetCookie(w, &http.Cookie{ + Name: config.SessionCookie, + Value: sessionID, + Path: config.Path, + Domain: config.Domain, + Secure: config.Secure, + HttpOnly: config.HttpOnly, + MaxAge: sessionManager.MaxAge, + }) + }) + } +} + +func SessionFromRequest(r *http.Request) *SessionManager { + return r.Context().Value(sessionKey).(*SessionManager) +} diff --git a/session/bbolt/bbolt.go b/session/bbolt/bbolt.go new file mode 100644 index 0000000..8e55d25 --- /dev/null +++ b/session/bbolt/bbolt.go @@ -0,0 +1,61 @@ +package bbolt + +import ( + "bytes" + "context" + "encoding/gob" + "log/slog" + + "go.etcd.io/bbolt" + "go.neonxp.ru/middleware/session" +) + +func New(db *bbolt.DB, bucketName []byte) session.Store { + return &Store{ + db: db, + bucketName: bucketName, + } +} + +type Store struct { + db *bbolt.DB + bucketName []byte +} + +func (s *Store) Load(ctx context.Context, sessionID string) session.Value { + v := session.Value{} + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(s.bucketName) + if bucket == nil { + // no bucket -- normal situation + return nil + } + vb := bucket.Get([]byte(sessionID)) + if vb == nil { + // no session -- no error + return nil + } + rdr := bytes.NewBuffer(vb) + + return gob.NewDecoder(rdr).Decode(&v) + }) + if err != nil { + slog.WarnContext(ctx, "failed load session", slog.Any("error", err)) + } + return v +} + +func (s *Store) Save(ctx context.Context, sessionID string, value session.Value) error { + return s.db.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(s.bucketName) + if err != nil { + return err + } + wrt := bytes.NewBuffer([]byte{}) + if err := gob.NewEncoder(wrt).Encode(value); err != nil { + return err + } + + return bucket.Put([]byte(sessionID), wrt.Bytes()) + }) +} diff --git a/session/bbolt/go.mod b/session/bbolt/go.mod new file mode 100644 index 0000000..f50214b --- /dev/null +++ b/session/bbolt/go.mod @@ -0,0 +1,7 @@ +module gitrepo.ru/neonxp/middleware/session/bbolt + +go 1.22.5 + +require go.etcd.io/bbolt v1.3.10 + +require golang.org/x/sys v0.4.0 // indirect diff --git a/session/bbolt/go.sum b/session/bbolt/go.sum new file mode 100644 index 0000000..3a4cae6 --- /dev/null +++ b/session/bbolt/go.sum @@ -0,0 +1,4 @@ +go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= +go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/session/memstore.go b/session/memstore.go new file mode 100644 index 0000000..97d70a1 --- /dev/null +++ b/session/memstore.go @@ -0,0 +1,25 @@ +package session + +import ( + "context" + "sync" +) + +type MemoryStore struct { + store sync.Map +} + +func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value { + val, ok := s.store.Load(sessionID) + if ok { + return val.(Value) + } + + return Value{} +} + +func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error { + s.store.Store(sessionID, value) + + return nil +} diff --git a/session/store.go b/session/store.go new file mode 100644 index 0000000..adef69a --- /dev/null +++ b/session/store.go @@ -0,0 +1,17 @@ +package session + +import ( + "context" + "errors" +) + +var ( + ErrSessionNotFound = errors.New("session not found") +) + +type Store interface { + Load(ctx context.Context, sessionID string) Value + Save(ctx context.Context, sessionID string, value Value) error +} + +type Value map[string]any