Пересмотрел механизм сессий
This commit is contained in:
parent
c261597b9a
commit
623eaf165a
6 changed files with 98 additions and 70 deletions
|
@ -4,8 +4,4 @@ type ctxKey int
|
|||
|
||||
const (
|
||||
requestIDKey ctxKey = iota
|
||||
SessionIDKey
|
||||
SessionValueKey
|
||||
SessionConfigKey
|
||||
SessionStorerKey
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@ package session
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"log/slog"
|
||||
|
||||
|
@ -21,8 +20,8 @@ type BoltStore struct {
|
|||
bucketName []byte
|
||||
}
|
||||
|
||||
func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
|
||||
v := Value{}
|
||||
func (s *BoltStore) Load(sessionID string) Values {
|
||||
v := Values{}
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(s.bucketName)
|
||||
if bucket == nil {
|
||||
|
@ -39,12 +38,12 @@ func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
|
|||
return gob.NewDecoder(rdr).Decode(&v)
|
||||
})
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "failed load session", slog.Any("error", err))
|
||||
slog.Warn("failed load session", slog.Any("error", err))
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error {
|
||||
func (s *BoltStore) Save(sessionID string, value Values) error {
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
||||
if err != nil {
|
||||
|
@ -59,7 +58,7 @@ func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) err
|
|||
})
|
||||
}
|
||||
|
||||
func (s *BoltStore) Remove(ctx context.Context, sessionID string) error {
|
||||
func (s *BoltStore) Remove(sessionID string) error {
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
||||
if err != nil {
|
||||
|
|
9
middleware/session/context.go
Normal file
9
middleware/session/context.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package session
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
sessionManagerKey ctxKey = iota
|
||||
sessionIDKey
|
||||
sessionValueKey
|
||||
)
|
|
@ -1,7 +1,6 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -9,22 +8,22 @@ type MemoryStore struct {
|
|||
store sync.Map
|
||||
}
|
||||
|
||||
func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value {
|
||||
func (s *MemoryStore) Load(sessionID string) Values {
|
||||
val, ok := s.store.Load(sessionID)
|
||||
if ok {
|
||||
return val.(Value)
|
||||
return val.(Values)
|
||||
}
|
||||
|
||||
return Value{}
|
||||
return Values{}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error {
|
||||
func (s *MemoryStore) Save(sessionID string, value Values) error {
|
||||
s.store.Store(sessionID, value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error {
|
||||
func (s *MemoryStore) Remove(sessionID string) error {
|
||||
s.store.Delete(sessionID)
|
||||
|
||||
return nil
|
||||
|
|
|
@ -4,10 +4,9 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.neonxp.ru/mux"
|
||||
"go.neonxp.ru/mux/middleware"
|
||||
"go.neonxp.ru/objectid"
|
||||
)
|
||||
|
||||
|
@ -17,7 +16,7 @@ type Config struct {
|
|||
Domain string
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
MaxAge int
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
var DefaultConfig Config = Config{
|
||||
|
@ -26,69 +25,104 @@ var DefaultConfig Config = Config{
|
|||
Domain: "",
|
||||
Secure: false,
|
||||
HttpOnly: true,
|
||||
MaxAge: 30 * 3600,
|
||||
MaxAge: 365 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
func Middleware(config Config, storer Store) mux.Middleware {
|
||||
if storer == nil {
|
||||
storer = &MemoryStore{store: sync.Map{}}
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrNoSessionInContext = errors.New("no session in context")
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
config *Config
|
||||
storer Store
|
||||
}
|
||||
|
||||
func New(storer Store) *SessionManager {
|
||||
return NewWithConfig(&DefaultConfig, storer)
|
||||
}
|
||||
|
||||
func NewWithConfig(config *Config, storer Store) *SessionManager {
|
||||
return &SessionManager{
|
||||
config: config,
|
||||
storer: storer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionManager) Middleware() mux.Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
sessionID string
|
||||
values Value
|
||||
values Values
|
||||
)
|
||||
cookie, err := r.Cookie(config.SessionCookie)
|
||||
cookie, err := r.Cookie(s.config.SessionCookie)
|
||||
switch {
|
||||
case err == nil:
|
||||
sessionID = cookie.Value
|
||||
values = storer.Load(r.Context(), sessionID)
|
||||
values = s.storer.Load(sessionID)
|
||||
case errors.Is(err, http.ErrNoCookie):
|
||||
sessionID = objectid.New().String()
|
||||
values = Value{}
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: config.SessionCookie,
|
||||
Value: sessionID,
|
||||
Path: config.Path,
|
||||
Domain: config.Domain,
|
||||
Secure: config.Secure,
|
||||
HttpOnly: config.HttpOnly,
|
||||
MaxAge: config.MaxAge,
|
||||
})
|
||||
|
||||
ctx := context.WithValue(r.Context(), middleware.SessionValueKey, &values)
|
||||
ctx = context.WithValue(ctx, middleware.SessionIDKey, sessionID)
|
||||
ctx = context.WithValue(ctx, middleware.SessionConfigKey, config)
|
||||
ctx = context.WithValue(ctx, middleware.SessionStorerKey, storer)
|
||||
ctx := context.WithValue(r.Context(), sessionManagerKey, s)
|
||||
ctx = context.WithValue(ctx, sessionIDKey, sessionID)
|
||||
ctx = context.WithValue(ctx, sessionValueKey, values)
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
||||
storer.Save(r.Context(), sessionID, values)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FromRequest(r *http.Request) *Value {
|
||||
return r.Context().Value(middleware.SessionValueKey).(*Value)
|
||||
func (s *SessionManager) Values(ctx context.Context) Values {
|
||||
aValue := ctx.Value(sessionValueKey)
|
||||
values, ok := aValue.(Values)
|
||||
if !ok || values == nil {
|
||||
values = Values{}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func (s *SessionManager) Save(w http.ResponseWriter, r *http.Request, values Values) error {
|
||||
aSessionID := r.Context().Value(sessionIDKey)
|
||||
sessionID, ok := aSessionID.(string)
|
||||
if !ok {
|
||||
return ErrNoSessionInContext
|
||||
}
|
||||
|
||||
func Clear(w http.ResponseWriter, r *http.Request) {
|
||||
storer := r.Context().Value(middleware.SessionStorerKey).(Store)
|
||||
sessionID := r.Context().Value(middleware.SessionIDKey).(string)
|
||||
storer.Remove(r.Context(), sessionID)
|
||||
config := r.Context().Value(middleware.SessionConfigKey).(Config)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: config.SessionCookie,
|
||||
Name: s.config.SessionCookie,
|
||||
Value: sessionID,
|
||||
Path: config.Path,
|
||||
Domain: config.Domain,
|
||||
Secure: config.Secure,
|
||||
HttpOnly: config.HttpOnly,
|
||||
Path: s.config.Path,
|
||||
Domain: s.config.Domain,
|
||||
Secure: s.config.Secure,
|
||||
HttpOnly: s.config.HttpOnly,
|
||||
MaxAge: int(s.config.MaxAge.Seconds()),
|
||||
})
|
||||
|
||||
return s.storer.Save(sessionID, values)
|
||||
}
|
||||
func (s *SessionManager) Clear(w http.ResponseWriter, r *http.Request) error {
|
||||
aSessionID := r.Context().Value(sessionIDKey)
|
||||
sessionID, ok := aSessionID.(string)
|
||||
if !ok {
|
||||
return ErrNoSessionInContext
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: s.config.SessionCookie,
|
||||
Value: sessionID,
|
||||
Path: s.config.Path,
|
||||
Domain: s.config.Domain,
|
||||
Secure: s.config.Secure,
|
||||
HttpOnly: s.config.HttpOnly,
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
return s.storer.Remove(sessionID)
|
||||
}
|
||||
|
||||
func FromRequest(r *http.Request) *SessionManager {
|
||||
return r.Context().Value(sessionManagerKey).(*SessionManager)
|
||||
}
|
||||
|
|
|
@ -1,18 +1,9 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
Load(ctx context.Context, sessionID string) Value
|
||||
Save(ctx context.Context, sessionID string, value Value) error
|
||||
Remove(ctx context.Context, sessionID string) error
|
||||
Load(sessionID string) Values
|
||||
Save(sessionID string, value Values) error
|
||||
Remove(sessionID string) error
|
||||
}
|
||||
|
||||
type Value map[string]any
|
||||
type Values map[string]any
|
||||
|
|
Loading…
Reference in a new issue