174 lines
4.4 KiB
Go
174 lines
4.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"nadir/internal/auditlog"
|
|
|
|
"github.com/danielgtaylor/huma/v2"
|
|
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
|
"github.com/danielgtaylor/huma/v2/humatest"
|
|
)
|
|
|
|
func TestEnsurePAMService(t *testing.T) {
|
|
tempFile := filepath.Join(t.TempDir(), "nadir-pam-test")
|
|
oldPath := pamServicePath
|
|
pamServicePath = tempFile
|
|
defer func() { pamServicePath = oldPath }()
|
|
|
|
if err := EnsurePAMService(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
data, err := os.ReadFile(tempFile)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if string(data) != pamServiceContent {
|
|
t.Errorf("got content %q, want %q", string(data), pamServiceContent)
|
|
}
|
|
|
|
customContent := "custom pam content"
|
|
if err := os.WriteFile(tempFile, []byte(customContent), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := EnsurePAMService(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
data, err = os.ReadFile(tempFile)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if string(data) != customContent {
|
|
t.Errorf("EnsurePAMService clobbered file: got %q, want %q", string(data), customContent)
|
|
}
|
|
}
|
|
|
|
func TestLoginLogoutThrottling(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer auditStore.Close()
|
|
|
|
sessions, err := NewSessionStore(filepath.Join(tempDir, "sessions.db"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
|
|
|
authMock := func(username, password string) error {
|
|
if password == "correct" {
|
|
return nil
|
|
}
|
|
return errors.New("pam error")
|
|
}
|
|
throttle := newFailLimiter(3, 500*time.Millisecond)
|
|
registerLogin(api, sessions, auditStore, false, authMock, throttle)
|
|
RegisterLogout(api, sessions, false)
|
|
|
|
t.Run("failed login returns 401", func(t *testing.T) {
|
|
resp := api.Post("/api/login", struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}{
|
|
Username: "admin",
|
|
Password: "wrong",
|
|
})
|
|
if resp.Code != http.StatusUnauthorized {
|
|
t.Errorf("got code %d, want %d", resp.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
var sessionID string
|
|
t.Run("successful login returns session cookie", func(t *testing.T) {
|
|
resp := api.Post("/api/login", struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}{
|
|
Username: "admin",
|
|
Password: "correct",
|
|
})
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("got code %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
|
|
cookieHeader := resp.Header().Get("Set-Cookie")
|
|
if !strings.Contains(cookieHeader, "nadir_session_id=") {
|
|
t.Fatalf("Set-Cookie header missing nadir_session_id: %q", cookieHeader)
|
|
}
|
|
|
|
parts := strings.SplitSeq(cookieHeader, ";")
|
|
for part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if after, ok := strings.CutPrefix(part, "nadir_session_id="); ok {
|
|
sessionID = after
|
|
break
|
|
}
|
|
}
|
|
if sessionID == "" {
|
|
t.Fatal("nadir_session_id cookie not found")
|
|
}
|
|
if _, ok := sessions.GetByToken(sessionID); !ok {
|
|
t.Fatal("session not found in session store")
|
|
}
|
|
})
|
|
|
|
t.Run("logout invalidates session", func(t *testing.T) {
|
|
resp := api.Post("/api/logout", "Cookie: nadir_session_id="+sessionID, struct{}{})
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("logout failed: got code %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
if _, ok := sessions.GetByToken(sessionID); ok {
|
|
t.Fatal("session still valid after logout")
|
|
}
|
|
})
|
|
|
|
t.Run("throttling blocks after 3 failures", func(t *testing.T) {
|
|
for range 3 {
|
|
api.Post("/api/login", struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}{
|
|
Username: "throttled-user",
|
|
Password: "wrong",
|
|
})
|
|
}
|
|
resp := api.Post("/api/login", struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}{
|
|
Username: "throttled-user",
|
|
Password: "correct",
|
|
})
|
|
if resp.Code != http.StatusTooManyRequests {
|
|
t.Errorf("got code %d, want %d", resp.Code, http.StatusTooManyRequests)
|
|
}
|
|
})
|
|
|
|
t.Run("throttle reset allows login", func(t *testing.T) {
|
|
throttle.reset("throttled-user|")
|
|
resp := api.Post("/api/login", struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}{
|
|
Username: "throttled-user",
|
|
Password: "correct",
|
|
})
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("got code %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
})
|
|
}
|