// Package session implements a gorilla/sessions store that persists session
// data in a SQL table through bun.
package session

import (
	"context"
	"encoding/base32"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
	"github.com/uptrace/bun"
)

const (
	// sessionIDLen is the number of random bytes used to build a session ID.
	sessionIDLen = 32
	// defaultTableName is used when Options.TableName is empty.
	defaultTableName = "sessions"
	// defaultMaxAge is the default session lifetime in seconds (30 days).
	defaultMaxAge = 60 * 60 * 24 * 30
	// defaultPath is the default cookie path.
	defaultPath = "/"
	// expiresAtIndexName names the index on expires_at so that creating it
	// can be repeated safely on every start.
	expiresAtIndexName = "sessions_expires_at_idx"
)

// Options configures a Store.
type Options struct {
	// TableName is the name of the sessions table; it defaults to "sessions".
	//
	// NOTE(review): in the code visible here this value is only defaulted and
	// never passed to a query, while the bunSession model tag pins the table
	// to "sessions", so a custom name does not appear to take effect — confirm.
	TableName string
	// SkipCreateTable disables the table and index creation done by NewOptions.
	SkipCreateTable bool
}

// Store is a session store that keeps session values in the database and
// only the encoded session ID in the cookie.
type Store struct {
	db   *bun.DB
	opts Options
	// Codecs encode and decode both the cookie value and the stored data.
	Codecs []securecookie.Codec
	// SessionOpts holds the default options copied into every new session.
	SessionOpts *sessions.Options
}

// bunSession is the database row backing one session.
type bunSession struct {
	bun.BaseModel `bun:"table:sessions,alias:s"`

	ID        string `bun:",pk,unique"`
	Data      string
	CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	ExpiresAt time.Time
}

// KeyPairs is a list of keys given as strings, passed on to
// securecookie.CodecsFromPairs after conversion with ToKeys.
type KeyPairs []string

// ToKeys converts every key to a byte slice, preserving order.
func (k KeyPairs) ToKeys() [][]byte {
	keys := make([][]byte, len(k))
	for i, key := range k {
		keys[i] = []byte(key)
	}
	return keys
}

// New creates a Store with default Options.
func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
	return NewOptions(db, Options{}, keyPairs)
}

// NewOptions creates a Store with the given options.
//
// Unless opts.SkipCreateTable is set, it creates the sessions table and an
// index on expires_at if they do not exist yet, using a 30 second timeout.
// Any error from these statements is returned wrapped, with a nil Store.
func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
	st := &Store{
		db:     db,
		opts:   opts,
		Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...),
		SessionOpts: &sessions.Options{
			Path:   defaultPath,
			MaxAge: defaultMaxAge,
		},
	}
	if st.opts.TableName == "" {
		st.opts.TableName = defaultTableName
	}
	if st.opts.SkipCreateTable {
		return st, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	model := &bunSession{}
	if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
		return nil, fmt.Errorf("creating sessions table: %w", err)
	}
	// The index is named and guarded with IF NOT EXISTS so that constructing
	// a Store repeatedly neither fails nor piles up duplicate indexes.
	_, err := db.NewCreateIndex().
		IfNotExists().
		Model(model).
		Index(expiresAtIndexName).
		Column("expires_at").
		Exec(ctx)
	if err != nil {
		return nil, fmt.Errorf("creating index %s: %w", expiresAtIndexName, err)
	}
	return st, nil
}

// Get returns a session for the given name after adding it to the registry.
func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(st, name) } // New creates a session with name without adding it to the registry. func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(st, name) opts := *st.SessionOpts session.Options = &opts session.IsNew = true st.MaxAge(st.SessionOpts.MaxAge) // try fetch from db if there is a cookie s := st.getSessionFromCookie(r, session.Name()) if s != nil { if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil { return session, nil } session.ID = s.ID session.IsNew = false } return session, nil } // Save session and set cookie header func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { s := st.getSessionFromCookie(r, session.Name()) // delete if max age is < 0 if session.Options.MaxAge < 0 { if s != nil { if _, err := st.db.NewDelete().Model(&bunSession{ID: session.ID}).Exec(r.Context()); err != nil { return err } } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...) if err != nil { return err } now := time.Now() expire := now.Add(time.Second * time.Duration(session.Options.MaxAge)) if s == nil { // generate random session ID key suitable for storage in the db session.ID = strings.TrimRight( base32.StdEncoding.EncodeToString( securecookie.GenerateRandomKey(sessionIDLen)), "=") s = &bunSession{ ID: session.ID, Data: data, ExpiresAt: expire, } if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil { return err } } else { s.Data = data s.ExpiresAt = expire if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil { return err } } // set session id cookie id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...) 
if err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options)) return nil } // getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie func (st *Store) getSessionFromCookie(r *http.Request, name string) *bunSession { if cookie, err := r.Cookie(name); err == nil { sessionID := "" if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil { return nil } s := &bunSession{} err := st.db.NewSelect().Model(s).Where("id = ? AND expires_at > ?", sessionID, time.Now()).Scan(r.Context()) if err != nil { return nil } return s } return nil } // MaxAge sets the maximum age for the store and the underlying cookie // implementation. Individual sessions can be deleted by setting // Options.MaxAge = -1 for that session. func (st *Store) MaxAge(age int) { st.SessionOpts.MaxAge = age for _, codec := range st.Codecs { if sc, ok := codec.(*securecookie.SecureCookie); ok { sc.MaxAge(age) } } } // MaxLength restricts the maximum length of new sessions to l. // If l is 0 there is no limit to the size of a session, use with caution. // The default is 4096 (default for securecookie) func (st *Store) MaxLength(l int) { for _, c := range st.Codecs { if codec, ok := c.(*securecookie.SecureCookie); ok { codec.MaxLength(l) } } } // Cleanup deletes expired sessions func (st *Store) Cleanup() { _, err := st.db.NewDelete().Model(&bunSession{}).Where("expires_at <= ?", time.Now()).Exec(context.Background()) if err != nil { slog.Default().With("error", err).Error("cleanup") } } // PeriodicCleanup runs Cleanup every interval. Close quit channel to stop. func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) { t := time.NewTicker(interval) defer t.Stop() for { select { case <-t.C: st.Cleanup() case <-quit: return } } }