package session import ( "context" "encoding/base32" "log/slog" "net/http" "strings" "time" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "github.com/uptrace/bun" ) const ( sessionIDLen = 32 defaultTableName = "sessions" defaultMaxAge = 60 * 60 * 24 * 30 // 30 days defaultPath = "/" ) // Options for bunstore. type Options struct { TableName string SkipCreateTable bool } // Store represent a bunstore. type Store struct { db *bun.DB opts Options Codecs []securecookie.Codec SessionOpts *sessions.Options } type Model struct { bun.BaseModel `bun:"table:sessions,alias:s"` ID string `bun:",pk,unique"` Data string CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` ExpiresAt time.Time } type KeyPairs []string func (k KeyPairs) ToKeys() [][]byte { b := make([][]byte, 0, len(k)) for _, kk := range k { b = append(b, []byte(kk)) } return b } // New creates a new bunstore session. func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) { return NewOptions(db, Options{}, keyPairs) } // NewOptions creates a new bunstore session with options. func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) { st := &Store{ db: db, opts: opts, Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...), SessionOpts: &sessions.Options{ Path: defaultPath, MaxAge: defaultMaxAge, }, } return st, nil } // Get returns a session for the given name after adding it to the registry. func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(st, name) } // New creates a session with name without adding it to the registry. func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(st, name) opts := *st.SessionOpts session.Options = &opts session.IsNew = true st.MaxAge(st.SessionOpts.MaxAge) // try fetch from db if there is a cookie s := st.getSessionFromCookie(r, session.Name()) if s != nil { if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil { //nolint:nilerr return session, nil } session.ID = s.ID session.IsNew = false } return session, nil } // Save session and set cookie header. func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { s := st.getSessionFromCookie(r, session.Name()) // delete if max age is < 0 if session.Options.MaxAge < 0 { if s != nil { if _, err := st.db.NewDelete().Model(&Model{ID: session.ID}).WherePK("id").Exec(r.Context()); err != nil { return err } } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...) if err != nil { return err } now := time.Now() expire := now.Add(time.Second * time.Duration(session.Options.MaxAge)) if s == nil { // generate random session ID key suitable for storage in the db session.ID = strings.TrimRight( base32.StdEncoding.EncodeToString( securecookie.GenerateRandomKey(sessionIDLen)), "=") s = &Model{ ID: session.ID, Data: data, ExpiresAt: expire, } if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil { return err } } else { s.Data = data s.ExpiresAt = expire if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil { return err } } // set session id cookie id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...) if err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options)) return nil } // getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie. func (st *Store) getSessionFromCookie(r *http.Request, name string) *Model { if cookie, err := r.Cookie(name); err == nil { sessionID := "" if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil { return nil } s := &Model{} if err := st.db.NewSelect(). Model(s). Where("id = ? AND expires_at > ?", sessionID, time.Now()). Scan(r.Context()); err != nil { return nil } return s } return nil } // MaxAge sets the maximum age for the store and the underlying cookie // implementation. Individual sessions can be deleted by setting // Options.MaxAge = -1 for that session. func (st *Store) MaxAge(age int) { st.SessionOpts.MaxAge = age for _, codec := range st.Codecs { if sc, ok := codec.(*securecookie.SecureCookie); ok { sc.MaxAge(age) } } } // MaxLength restricts the maximum length of new sessions to l. // If l is 0 there is no limit to the size of a session, use with caution. // The default is 4096 (default for securecookie). func (st *Store) MaxLength(l int) { for _, c := range st.Codecs { if codec, ok := c.(*securecookie.SecureCookie); ok { codec.MaxLength(l) } } } // Cleanup deletes expired sessions. func (st *Store) Cleanup() { _, err := st.db.NewDelete().Model(&Model{}).Where("expires_at <= ?", time.Now()).Exec(context.Background()) if err != nil { slog.Default().With("error", err).Error("cleanup") } } // PeriodicCleanup runs Cleanup every interval. Close quit channel to stop. func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) { t := time.NewTicker(interval) defer t.Stop() for { select { case <-t.C: st.Cleanup() case <-quit: return } } }