Пересмотрел механизм сессий
This commit is contained in:
parent
c261597b9a
commit
623eaf165a
6 changed files with 98 additions and 70 deletions
|
@ -4,8 +4,4 @@ type ctxKey int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
requestIDKey ctxKey = iota
|
requestIDKey ctxKey = iota
|
||||||
SessionIDKey
|
|
||||||
SessionValueKey
|
|
||||||
SessionConfigKey
|
|
||||||
SessionStorerKey
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,7 +2,6 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
|
@ -21,8 +20,8 @@ type BoltStore struct {
|
||||||
bucketName []byte
|
bucketName []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
|
func (s *BoltStore) Load(sessionID string) Values {
|
||||||
v := Value{}
|
v := Values{}
|
||||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||||
bucket := tx.Bucket(s.bucketName)
|
bucket := tx.Bucket(s.bucketName)
|
||||||
if bucket == nil {
|
if bucket == nil {
|
||||||
|
@ -39,12 +38,12 @@ func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
|
||||||
return gob.NewDecoder(rdr).Decode(&v)
|
return gob.NewDecoder(rdr).Decode(&v)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.WarnContext(ctx, "failed load session", slog.Any("error", err))
|
slog.Warn("failed load session", slog.Any("error", err))
|
||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error {
|
func (s *BoltStore) Save(sessionID string, value Values) error {
|
||||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||||
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -59,7 +58,7 @@ func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BoltStore) Remove(ctx context.Context, sessionID string) error {
|
func (s *BoltStore) Remove(sessionID string) error {
|
||||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||||
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
9
middleware/session/context.go
Normal file
9
middleware/session/context.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
type ctxKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionManagerKey ctxKey = iota
|
||||||
|
sessionIDKey
|
||||||
|
sessionValueKey
|
||||||
|
)
|
|
@ -1,7 +1,6 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,22 +8,22 @@ type MemoryStore struct {
|
||||||
store sync.Map
|
store sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value {
|
func (s *MemoryStore) Load(sessionID string) Values {
|
||||||
val, ok := s.store.Load(sessionID)
|
val, ok := s.store.Load(sessionID)
|
||||||
if ok {
|
if ok {
|
||||||
return val.(Value)
|
return val.(Values)
|
||||||
}
|
}
|
||||||
|
|
||||||
return Value{}
|
return Values{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error {
|
func (s *MemoryStore) Save(sessionID string, value Values) error {
|
||||||
s.store.Store(sessionID, value)
|
s.store.Store(sessionID, value)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error {
|
func (s *MemoryStore) Remove(sessionID string) error {
|
||||||
s.store.Delete(sessionID)
|
s.store.Delete(sessionID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -4,10 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"time"
|
||||||
|
|
||||||
"go.neonxp.ru/mux"
|
"go.neonxp.ru/mux"
|
||||||
"go.neonxp.ru/mux/middleware"
|
|
||||||
"go.neonxp.ru/objectid"
|
"go.neonxp.ru/objectid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,7 +16,7 @@ type Config struct {
|
||||||
Domain string
|
Domain string
|
||||||
Secure bool
|
Secure bool
|
||||||
HttpOnly bool
|
HttpOnly bool
|
||||||
MaxAge int
|
MaxAge time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultConfig Config = Config{
|
var DefaultConfig Config = Config{
|
||||||
|
@ -26,69 +25,104 @@ var DefaultConfig Config = Config{
|
||||||
Domain: "",
|
Domain: "",
|
||||||
Secure: false,
|
Secure: false,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
MaxAge: 30 * 3600,
|
MaxAge: 365 * 24 * time.Hour,
|
||||||
}
|
}
|
||||||
|
|
||||||
func Middleware(config Config, storer Store) mux.Middleware {
|
var (
|
||||||
if storer == nil {
|
ErrSessionNotFound = errors.New("session not found")
|
||||||
storer = &MemoryStore{store: sync.Map{}}
|
ErrNoSessionInContext = errors.New("no session in context")
|
||||||
}
|
)
|
||||||
|
|
||||||
|
type SessionManager struct {
|
||||||
|
config *Config
|
||||||
|
storer Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(storer Store) *SessionManager {
|
||||||
|
return NewWithConfig(&DefaultConfig, storer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithConfig(config *Config, storer Store) *SessionManager {
|
||||||
|
return &SessionManager{
|
||||||
|
config: config,
|
||||||
|
storer: storer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionManager) Middleware() mux.Middleware {
|
||||||
return func(h http.Handler) http.Handler {
|
return func(h http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var (
|
var (
|
||||||
sessionID string
|
sessionID string
|
||||||
values Value
|
values Values
|
||||||
)
|
)
|
||||||
cookie, err := r.Cookie(config.SessionCookie)
|
cookie, err := r.Cookie(s.config.SessionCookie)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
sessionID = cookie.Value
|
sessionID = cookie.Value
|
||||||
values = storer.Load(r.Context(), sessionID)
|
values = s.storer.Load(sessionID)
|
||||||
case errors.Is(err, http.ErrNoCookie):
|
case errors.Is(err, http.ErrNoCookie):
|
||||||
sessionID = objectid.New().String()
|
sessionID = objectid.New().String()
|
||||||
values = Value{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
ctx := context.WithValue(r.Context(), sessionManagerKey, s)
|
||||||
Name: config.SessionCookie,
|
ctx = context.WithValue(ctx, sessionIDKey, sessionID)
|
||||||
Value: sessionID,
|
ctx = context.WithValue(ctx, sessionValueKey, values)
|
||||||
Path: config.Path,
|
|
||||||
Domain: config.Domain,
|
|
||||||
Secure: config.Secure,
|
|
||||||
HttpOnly: config.HttpOnly,
|
|
||||||
MaxAge: config.MaxAge,
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), middleware.SessionValueKey, &values)
|
|
||||||
ctx = context.WithValue(ctx, middleware.SessionIDKey, sessionID)
|
|
||||||
ctx = context.WithValue(ctx, middleware.SessionConfigKey, config)
|
|
||||||
ctx = context.WithValue(ctx, middleware.SessionStorerKey, storer)
|
|
||||||
|
|
||||||
h.ServeHTTP(w, r.WithContext(ctx))
|
h.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
|
||||||
storer.Save(r.Context(), sessionID, values)
|
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FromRequest(r *http.Request) *Value {
|
func (s *SessionManager) Values(ctx context.Context) Values {
|
||||||
return r.Context().Value(middleware.SessionValueKey).(*Value)
|
aValue := ctx.Value(sessionValueKey)
|
||||||
|
values, ok := aValue.(Values)
|
||||||
|
if !ok || values == nil {
|
||||||
|
values = Values{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func Clear(w http.ResponseWriter, r *http.Request) {
|
func (s *SessionManager) Save(w http.ResponseWriter, r *http.Request, values Values) error {
|
||||||
storer := r.Context().Value(middleware.SessionStorerKey).(Store)
|
aSessionID := r.Context().Value(sessionIDKey)
|
||||||
sessionID := r.Context().Value(middleware.SessionIDKey).(string)
|
sessionID, ok := aSessionID.(string)
|
||||||
storer.Remove(r.Context(), sessionID)
|
if !ok {
|
||||||
config := r.Context().Value(middleware.SessionConfigKey).(Config)
|
return ErrNoSessionInContext
|
||||||
|
}
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: config.SessionCookie,
|
Name: s.config.SessionCookie,
|
||||||
Value: sessionID,
|
Value: sessionID,
|
||||||
Path: config.Path,
|
Path: s.config.Path,
|
||||||
Domain: config.Domain,
|
Domain: s.config.Domain,
|
||||||
Secure: config.Secure,
|
Secure: s.config.Secure,
|
||||||
HttpOnly: config.HttpOnly,
|
HttpOnly: s.config.HttpOnly,
|
||||||
|
MaxAge: int(s.config.MaxAge.Seconds()),
|
||||||
|
})
|
||||||
|
|
||||||
|
return s.storer.Save(sessionID, values)
|
||||||
|
}
|
||||||
|
func (s *SessionManager) Clear(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
aSessionID := r.Context().Value(sessionIDKey)
|
||||||
|
sessionID, ok := aSessionID.(string)
|
||||||
|
if !ok {
|
||||||
|
return ErrNoSessionInContext
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: s.config.SessionCookie,
|
||||||
|
Value: sessionID,
|
||||||
|
Path: s.config.Path,
|
||||||
|
Domain: s.config.Domain,
|
||||||
|
Secure: s.config.Secure,
|
||||||
|
HttpOnly: s.config.HttpOnly,
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return s.storer.Remove(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromRequest(r *http.Request) *SessionManager {
|
||||||
|
return r.Context().Value(sessionManagerKey).(*SessionManager)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,9 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrSessionNotFound = errors.New("session not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
Load(ctx context.Context, sessionID string) Value
|
Load(sessionID string) Values
|
||||||
Save(ctx context.Context, sessionID string, value Value) error
|
Save(sessionID string, value Values) error
|
||||||
Remove(ctx context.Context, sessionID string) error
|
Remove(sessionID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Value map[string]any
|
type Values map[string]any
|
||||||
|
|
Loading…
Reference in a new issue