156 lines
4.5 KiB
Go
156 lines
4.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// sessionTTL is how long a newly created session remains valid.
var sessionTTL = time.Hour * 24
|
|
|
|
// Session is what a valid session token resolves to. Only the owning
// username is carried; the expiry is kept in the database row.
type Session struct {
	Username string // account the session was created for
}
|
|
|
|
// SessionStore keeps sessions (and the login-throttle table) in an embedded
// SQLite database, so they survive process restarts. Expired session rows
// are removed lazily rather than by a background job.
//
// Construct it with NewSessionStore; the zero value has no database handle.
type SessionStore struct {
	db *sql.DB // sole handle to the SQLite file
}
|
|
|
|
// NewSessionStore opens (creating if needed) the SQLite database at path and
|
|
// ensures the sessions table exists. The parent directory is created too.
|
|
func NewSessionStore(path string) (*SessionStore, error) {
|
|
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
|
return nil, fmt.Errorf("session db dir: %w", err)
|
|
}
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open session db: %w", err)
|
|
}
|
|
// ponytail: single connection serializes access, avoiding SQLite's
|
|
// "database is locked" under concurrent writes. A sysadmin panel's session
|
|
// rate is trivial; switch to WAL + a real pool only if that ever bites.
|
|
db.SetMaxOpenConns(1)
|
|
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS sessions (
|
|
token TEXT PRIMARY KEY,
|
|
username TEXT NOT NULL,
|
|
expires_at INTEGER NOT NULL
|
|
)`); err != nil {
|
|
return nil, fmt.Errorf("create sessions table: %w", err)
|
|
}
|
|
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS throttle (
|
|
key TEXT PRIMARY KEY,
|
|
count INTEGER NOT NULL,
|
|
until INTEGER NOT NULL
|
|
)`); err != nil {
|
|
return nil, fmt.Errorf("create throttle table: %w", err)
|
|
}
|
|
// Restrict the DB file so only the owning user (root) can read it.
|
|
if err := os.Chmod(path, 0600); err != nil {
|
|
return nil, fmt.Errorf("chmod session db: %w", err)
|
|
}
|
|
return &SessionStore{db: db}, nil
|
|
}
|
|
|
|
// Ping reports whether the session database is reachable. Used by the health
|
|
// check.
|
|
func (s *SessionStore) Ping() error { return s.db.Ping() }
|
|
|
|
func (s *SessionStore) Create(username string) (string, error) {
|
|
token := randomToken()
|
|
expires := time.Now().Add(sessionTTL)
|
|
if _, err := s.db.Exec(
|
|
`INSERT INTO sessions (token, username, expires_at) VALUES (?, ?, ?)`,
|
|
hashSessionToken(token), username, expires.Unix(),
|
|
); err != nil {
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// Delete removes a session, invalidating it immediately (logout). Deleting an
|
|
// unknown token is a no-op.
|
|
func (s *SessionStore) Delete(token string) error {
|
|
_, err := s.db.Exec(`DELETE FROM sessions WHERE token = ?`, hashSessionToken(token))
|
|
return err
|
|
}
|
|
|
|
// DeleteByUsername removes every session for the given user, used when a
|
|
// password change should invalidate all existing sessions.
|
|
func (s *SessionStore) DeleteByUsername(username string) error {
|
|
_, err := s.db.Exec(`DELETE FROM sessions WHERE username = ?`, username)
|
|
return err
|
|
}
|
|
|
|
func (s *SessionStore) GetByToken(token string) (Session, bool) {
|
|
var username string
|
|
var expires int64
|
|
h := hashSessionToken(token)
|
|
err := s.db.QueryRow(
|
|
`SELECT username, expires_at FROM sessions WHERE token = ?`, h,
|
|
).Scan(&username, &expires)
|
|
if err != nil {
|
|
return Session{}, false
|
|
}
|
|
if time.Now().Unix() > expires {
|
|
s.db.Exec(`DELETE FROM sessions WHERE token = ?`, h)
|
|
return Session{}, false
|
|
}
|
|
return Session{Username: username}, true
|
|
}
|
|
|
|
// NewPersistentFailLimiter returns a failLimiter whose state survives process
|
|
// restarts via the session SQLite database.
|
|
func (s *SessionStore) NewPersistentFailLimiter(max int, window time.Duration) *failLimiter {
|
|
l := newFailLimiter(max, window)
|
|
// Load existing entries, skipping expired ones.
|
|
rows, err := s.db.Query(`SELECT key, count, until FROM throttle`)
|
|
if err != nil {
|
|
return l
|
|
}
|
|
defer rows.Close()
|
|
now := time.Now()
|
|
for rows.Next() {
|
|
var k string
|
|
var c int
|
|
var u int64
|
|
if err := rows.Scan(&k, &c, &u); err != nil {
|
|
continue
|
|
}
|
|
until := time.Unix(u, 0)
|
|
if now.After(until) {
|
|
s.db.Exec(`DELETE FROM throttle WHERE key = ?`, k)
|
|
continue
|
|
}
|
|
l.attempts[k] = &attemptState{count: c, until: until}
|
|
}
|
|
// Wire persistence: writes to DB on every mutation.
|
|
l.sync = func(key string, st *attemptState) {
|
|
if st == nil {
|
|
s.db.Exec(`DELETE FROM throttle WHERE key = ?`, key)
|
|
} else {
|
|
s.db.Exec(`INSERT OR REPLACE INTO throttle (key, count, until) VALUES (?, ?, ?)`, key, st.count, st.until.Unix())
|
|
}
|
|
}
|
|
return l
|
|
}
|
|
|
|
// hashSessionToken maps a raw session token to the key stored in the
// database: the lowercase hex encoding of its SHA-256 digest. Only this
// digest is persisted, never the raw token handed to the client.
func hashSessionToken(token string) string {
	hasher := sha256.New()
	hasher.Write([]byte(token)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// randomToken returns a fresh session token: 32 bytes (256 bits) from
// crypto/rand, hex-encoded to 64 characters.
//
// The read error is checked explicitly. Since Go 1.24, crypto/rand.Read never
// returns an error (it crashes the program instead). On earlier releases it
// can fail, and ignoring that would hand out a predictable, all-zero or
// partially filled token. Callers have no error path for token generation,
// so panicking is the only safe response.
func randomToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|