Пересмотрел механизм сессий

This commit is contained in:
Alexander NeonXP Kiryukhin 2024-07-30 00:45:25 +03:00
parent c261597b9a
commit 623eaf165a
Signed by: NeonXP
GPG key ID: 35E33E1AB7776B39
6 changed files with 98 additions and 70 deletions

View file

@ -4,8 +4,4 @@ type ctxKey int
const ( const (
requestIDKey ctxKey = iota requestIDKey ctxKey = iota
SessionIDKey
SessionValueKey
SessionConfigKey
SessionStorerKey
) )

View file

@ -2,7 +2,6 @@ package session
import ( import (
"bytes" "bytes"
"context"
"encoding/gob" "encoding/gob"
"log/slog" "log/slog"
@ -21,8 +20,8 @@ type BoltStore struct {
bucketName []byte bucketName []byte
} }
func (s *BoltStore) Load(ctx context.Context, sessionID string) Value { func (s *BoltStore) Load(sessionID string) Values {
v := Value{} v := Values{}
err := s.db.View(func(tx *bbolt.Tx) error { err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(s.bucketName) bucket := tx.Bucket(s.bucketName)
if bucket == nil { if bucket == nil {
@ -39,12 +38,12 @@ func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
return gob.NewDecoder(rdr).Decode(&v) return gob.NewDecoder(rdr).Decode(&v)
}) })
if err != nil { if err != nil {
slog.WarnContext(ctx, "failed load session", slog.Any("error", err)) slog.Warn("failed load session", slog.Any("error", err))
} }
return v return v
} }
func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error { func (s *BoltStore) Save(sessionID string, value Values) error {
return s.db.Update(func(tx *bbolt.Tx) error { return s.db.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(s.bucketName) bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
if err != nil { if err != nil {
@ -59,7 +58,7 @@ func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) err
}) })
} }
func (s *BoltStore) Remove(ctx context.Context, sessionID string) error { func (s *BoltStore) Remove(sessionID string) error {
return s.db.Update(func(tx *bbolt.Tx) error { return s.db.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(s.bucketName) bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
if err != nil { if err != nil {

View file

@ -0,0 +1,9 @@
package session
type ctxKey int
const (
sessionManagerKey ctxKey = iota
sessionIDKey
sessionValueKey
)

View file

@ -1,7 +1,6 @@
package session package session
import ( import (
"context"
"sync" "sync"
) )
@ -9,22 +8,22 @@ type MemoryStore struct {
store sync.Map store sync.Map
} }
func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value { func (s *MemoryStore) Load(sessionID string) Values {
val, ok := s.store.Load(sessionID) val, ok := s.store.Load(sessionID)
if ok { if ok {
return val.(Value) return val.(Values)
} }
return Value{} return Values{}
} }
func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error { func (s *MemoryStore) Save(sessionID string, value Values) error {
s.store.Store(sessionID, value) s.store.Store(sessionID, value)
return nil return nil
} }
func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error { func (s *MemoryStore) Remove(sessionID string) error {
s.store.Delete(sessionID) s.store.Delete(sessionID)
return nil return nil

View file

@ -4,10 +4,9 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"sync" "time"
"go.neonxp.ru/mux" "go.neonxp.ru/mux"
"go.neonxp.ru/mux/middleware"
"go.neonxp.ru/objectid" "go.neonxp.ru/objectid"
) )
@ -17,7 +16,7 @@ type Config struct {
Domain string Domain string
Secure bool Secure bool
HttpOnly bool HttpOnly bool
MaxAge int MaxAge time.Duration
} }
var DefaultConfig Config = Config{ var DefaultConfig Config = Config{
@ -26,69 +25,104 @@ var DefaultConfig Config = Config{
Domain: "", Domain: "",
Secure: false, Secure: false,
HttpOnly: true, HttpOnly: true,
MaxAge: 30 * 3600, MaxAge: 365 * 24 * time.Hour,
} }
func Middleware(config Config, storer Store) mux.Middleware { var (
if storer == nil { ErrSessionNotFound = errors.New("session not found")
storer = &MemoryStore{store: sync.Map{}} ErrNoSessionInContext = errors.New("no session in context")
} )
type SessionManager struct {
config *Config
storer Store
}
func New(storer Store) *SessionManager {
return NewWithConfig(&DefaultConfig, storer)
}
func NewWithConfig(config *Config, storer Store) *SessionManager {
return &SessionManager{
config: config,
storer: storer,
}
}
func (s *SessionManager) Middleware() mux.Middleware {
return func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ( var (
sessionID string sessionID string
values Value values Values
) )
cookie, err := r.Cookie(config.SessionCookie) cookie, err := r.Cookie(s.config.SessionCookie)
switch { switch {
case err == nil: case err == nil:
sessionID = cookie.Value sessionID = cookie.Value
values = storer.Load(r.Context(), sessionID) values = s.storer.Load(sessionID)
case errors.Is(err, http.ErrNoCookie): case errors.Is(err, http.ErrNoCookie):
sessionID = objectid.New().String() sessionID = objectid.New().String()
values = Value{}
} }
http.SetCookie(w, &http.Cookie{ ctx := context.WithValue(r.Context(), sessionManagerKey, s)
Name: config.SessionCookie, ctx = context.WithValue(ctx, sessionIDKey, sessionID)
Value: sessionID, ctx = context.WithValue(ctx, sessionValueKey, values)
Path: config.Path,
Domain: config.Domain,
Secure: config.Secure,
HttpOnly: config.HttpOnly,
MaxAge: config.MaxAge,
})
ctx := context.WithValue(r.Context(), middleware.SessionValueKey, &values)
ctx = context.WithValue(ctx, middleware.SessionIDKey, sessionID)
ctx = context.WithValue(ctx, middleware.SessionConfigKey, config)
ctx = context.WithValue(ctx, middleware.SessionStorerKey, storer)
h.ServeHTTP(w, r.WithContext(ctx)) h.ServeHTTP(w, r.WithContext(ctx))
storer.Save(r.Context(), sessionID, values)
}) })
} }
} }
func FromRequest(r *http.Request) *Value { func (s *SessionManager) Values(ctx context.Context) Values {
return r.Context().Value(middleware.SessionValueKey).(*Value) aValue := ctx.Value(sessionValueKey)
values, ok := aValue.(Values)
if !ok || values == nil {
values = Values{}
}
return values
} }
func Clear(w http.ResponseWriter, r *http.Request) { func (s *SessionManager) Save(w http.ResponseWriter, r *http.Request, values Values) error {
storer := r.Context().Value(middleware.SessionStorerKey).(Store) aSessionID := r.Context().Value(sessionIDKey)
sessionID := r.Context().Value(middleware.SessionIDKey).(string) sessionID, ok := aSessionID.(string)
storer.Remove(r.Context(), sessionID) if !ok {
config := r.Context().Value(middleware.SessionConfigKey).(Config) return ErrNoSessionInContext
}
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: config.SessionCookie, Name: s.config.SessionCookie,
Value: sessionID, Value: sessionID,
Path: config.Path, Path: s.config.Path,
Domain: config.Domain, Domain: s.config.Domain,
Secure: config.Secure, Secure: s.config.Secure,
HttpOnly: config.HttpOnly, HttpOnly: s.config.HttpOnly,
MaxAge: int(s.config.MaxAge.Seconds()),
})
return s.storer.Save(sessionID, values)
}
func (s *SessionManager) Clear(w http.ResponseWriter, r *http.Request) error {
aSessionID := r.Context().Value(sessionIDKey)
sessionID, ok := aSessionID.(string)
if !ok {
return ErrNoSessionInContext
}
http.SetCookie(w, &http.Cookie{
Name: s.config.SessionCookie,
Value: sessionID,
Path: s.config.Path,
Domain: s.config.Domain,
Secure: s.config.Secure,
HttpOnly: s.config.HttpOnly,
MaxAge: -1, MaxAge: -1,
}) })
return s.storer.Remove(sessionID)
}
func FromRequest(r *http.Request) *SessionManager {
return r.Context().Value(sessionManagerKey).(*SessionManager)
} }

View file

@ -1,18 +1,9 @@
package session package session
import (
"context"
"errors"
)
var (
ErrSessionNotFound = errors.New("session not found")
)
type Store interface { type Store interface {
Load(ctx context.Context, sessionID string) Value Load(sessionID string) Values
Save(ctx context.Context, sessionID string, value Value) error Save(sessionID string, value Values) error
Remove(ctx context.Context, sessionID string) error Remove(sessionID string) error
} }
type Value map[string]any type Values map[string]any