231 lines
6 KiB
Go
231 lines
6 KiB
Go
|
package session
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/base32"
|
||
|
"log/slog"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gorilla/securecookie"
|
||
|
"github.com/gorilla/sessions"
|
||
|
"github.com/uptrace/bun"
|
||
|
)
|
||
|
|
||
|
// Defaults used by the store when the caller does not override them.
const (
	// sessionIDLen is the length passed to securecookie.GenerateRandomKey
	// when a new session ID is generated.
	sessionIDLen = 32
	// defaultTableName replaces an empty Options.TableName.
	defaultTableName = "sessions"
	// defaultMaxAge is the default session lifetime in seconds (30 days).
	defaultMaxAge = 30 * 24 * 60 * 60
	// defaultPath is the default cookie path.
	defaultPath = "/"
)
|
||
|
|
||
|
// Options configures a Store created with NewOptions.
type Options struct {
	// TableName is the name of the session table. NewOptions replaces an
	// empty value with "sessions".
	//
	// NOTE(review): the bunSession model pins its table through a struct tag
	// and no query in this file reads TableName, so a custom value appears to
	// have no effect on the SQL that is issued — confirm before relying on it.
	TableName string

	// SkipCreateTable, when true, makes NewOptions skip creating the session
	// table and its expires_at index.
	SkipCreateTable bool
}
|
||
|
|
||
|
// Store represent a bunstore
|
||
|
type Store struct {
|
||
|
db *bun.DB
|
||
|
opts Options
|
||
|
Codecs []securecookie.Codec
|
||
|
SessionOpts *sessions.Options
|
||
|
}
|
||
|
|
||
|
type bunSession struct {
|
||
|
bun.BaseModel `bun:"table:sessions,alias:s"`
|
||
|
|
||
|
ID string `bun:",pk,unique"`
|
||
|
Data string
|
||
|
CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
|
||
|
UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
|
||
|
ExpiresAt time.Time
|
||
|
}
|
||
|
|
||
|
// KeyPairs is a list of keys, given as strings, that NewOptions converts with
// ToKeys and hands to securecookie.CodecsFromPairs.
type KeyPairs []string

// ToKeys returns every key converted to a byte slice, in the original order.
// The result is never nil, even when the receiver is nil or empty.
func (k KeyPairs) ToKeys() [][]byte {
	keys := make([][]byte, len(k))
	for i := range k {
		keys[i] = []byte(k[i])
	}
	return keys
}
|
||
|
|
||
|
// New creates a new bunstore session
|
||
|
func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
|
||
|
return NewOptions(db, Options{}, keyPairs)
|
||
|
}
|
||
|
|
||
|
// NewOptions creates a new bunstore session with options
|
||
|
func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
|
||
|
st := &Store{
|
||
|
db: db,
|
||
|
opts: opts,
|
||
|
Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...),
|
||
|
SessionOpts: &sessions.Options{
|
||
|
Path: defaultPath,
|
||
|
MaxAge: defaultMaxAge,
|
||
|
},
|
||
|
}
|
||
|
if st.opts.TableName == "" {
|
||
|
st.opts.TableName = defaultTableName
|
||
|
}
|
||
|
|
||
|
if !st.opts.SkipCreateTable {
|
||
|
model := &bunSession{}
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
|
defer cancel()
|
||
|
if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if _, err := db.NewCreateIndex().Model(model).Column("expires_at").Exec(ctx); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return st, nil
|
||
|
}
|
||
|
|
||
|
// Get returns a session for the given name after adding it to the registry.
|
||
|
func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||
|
return sessions.GetRegistry(r).Get(st, name)
|
||
|
}
|
||
|
|
||
|
// New creates a session with name without adding it to the registry.
|
||
|
func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) {
|
||
|
session := sessions.NewSession(st, name)
|
||
|
opts := *st.SessionOpts
|
||
|
session.Options = &opts
|
||
|
session.IsNew = true
|
||
|
|
||
|
st.MaxAge(st.SessionOpts.MaxAge)
|
||
|
|
||
|
// try fetch from db if there is a cookie
|
||
|
s := st.getSessionFromCookie(r, session.Name())
|
||
|
if s != nil {
|
||
|
if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil {
|
||
|
return session, nil
|
||
|
}
|
||
|
session.ID = s.ID
|
||
|
session.IsNew = false
|
||
|
}
|
||
|
|
||
|
return session, nil
|
||
|
}
|
||
|
|
||
|
// Save session and set cookie header
|
||
|
func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
||
|
s := st.getSessionFromCookie(r, session.Name())
|
||
|
|
||
|
// delete if max age is < 0
|
||
|
if session.Options.MaxAge < 0 {
|
||
|
if s != nil {
|
||
|
if _, err := st.db.NewDelete().Model(&bunSession{ID: session.ID}).Exec(r.Context()); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
now := time.Now()
|
||
|
expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))
|
||
|
|
||
|
if s == nil {
|
||
|
// generate random session ID key suitable for storage in the db
|
||
|
session.ID = strings.TrimRight(
|
||
|
base32.StdEncoding.EncodeToString(
|
||
|
securecookie.GenerateRandomKey(sessionIDLen)), "=")
|
||
|
s = &bunSession{
|
||
|
ID: session.ID,
|
||
|
Data: data,
|
||
|
ExpiresAt: expire,
|
||
|
}
|
||
|
if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
s.Data = data
|
||
|
s.ExpiresAt = expire
|
||
|
if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// set session id cookie
|
||
|
id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options))
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie
|
||
|
func (st *Store) getSessionFromCookie(r *http.Request, name string) *bunSession {
|
||
|
if cookie, err := r.Cookie(name); err == nil {
|
||
|
sessionID := ""
|
||
|
if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
s := &bunSession{}
|
||
|
err := st.db.NewSelect().Model(s).Where("id = ? AND expires_at > ?", sessionID, time.Now()).Scan(r.Context())
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
return s
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// MaxAge sets the maximum age for the store and the underlying cookie
|
||
|
// implementation. Individual sessions can be deleted by setting
|
||
|
// Options.MaxAge = -1 for that session.
|
||
|
func (st *Store) MaxAge(age int) {
|
||
|
st.SessionOpts.MaxAge = age
|
||
|
for _, codec := range st.Codecs {
|
||
|
if sc, ok := codec.(*securecookie.SecureCookie); ok {
|
||
|
sc.MaxAge(age)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// MaxLength restricts the maximum length of new sessions to l.
|
||
|
// If l is 0 there is no limit to the size of a session, use with caution.
|
||
|
// The default is 4096 (default for securecookie)
|
||
|
func (st *Store) MaxLength(l int) {
|
||
|
for _, c := range st.Codecs {
|
||
|
if codec, ok := c.(*securecookie.SecureCookie); ok {
|
||
|
codec.MaxLength(l)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Cleanup deletes expired sessions
|
||
|
func (st *Store) Cleanup() {
|
||
|
_, err := st.db.NewDelete().Model(&bunSession{}).Where("expires_at <= ?", time.Now()).Exec(context.Background())
|
||
|
if err != nil {
|
||
|
slog.Default().With("error", err).Error("cleanup")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop.
|
||
|
func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) {
|
||
|
t := time.NewTicker(interval)
|
||
|
defer t.Stop()
|
||
|
for {
|
||
|
select {
|
||
|
case <-t.C:
|
||
|
st.Cleanup()
|
||
|
case <-quit:
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|