package auth import ( "errors" "net/http" "os" "path/filepath" "strings" "testing" "time" "nadir/internal/auditlog" "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/adapters/humago" "github.com/danielgtaylor/huma/v2/humatest" ) func TestEnsurePAMService(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "nadir-pam-test") oldPath := pamServicePath pamServicePath = tempFile defer func() { pamServicePath = oldPath }() if err := EnsurePAMService(); err != nil { t.Fatal(err) } data, err := os.ReadFile(tempFile) if err != nil { t.Fatal(err) } if string(data) != pamServiceContent { t.Errorf("got content %q, want %q", string(data), pamServiceContent) } customContent := "custom pam content" if err := os.WriteFile(tempFile, []byte(customContent), 0644); err != nil { t.Fatal(err) } if err := EnsurePAMService(); err != nil { t.Fatal(err) } data, err = os.ReadFile(tempFile) if err != nil { t.Fatal(err) } if string(data) != customContent { t.Errorf("EnsurePAMService clobbered file: got %q, want %q", string(data), customContent) } } func TestLoginLogoutThrottling(t *testing.T) { tempDir := t.TempDir() auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db")) if err != nil { t.Fatal(err) } defer auditStore.Close() sessions, err := NewSessionStore(filepath.Join(tempDir, "sessions.db")) if err != nil { t.Fatal(err) } mux := http.NewServeMux() api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) authMock := func(username, password string) error { if password == "correct" { return nil } return errors.New("pam error") } throttle := newFailLimiter(3, 500*time.Millisecond) registerLogin(api, sessions, auditStore, false, authMock, throttle) RegisterLogout(api, sessions, false) t.Run("failed login returns 401", func(t *testing.T) { resp := api.Post("/api/login", struct { Username string `json:"username"` Password string `json:"password"` }{ Username: "admin", Password: "wrong", }) if resp.Code != http.StatusUnauthorized { t.Errorf("got code %d, want %d", resp.Code, http.StatusUnauthorized) } }) var sessionID string t.Run("successful login returns session cookie", func(t *testing.T) { resp := api.Post("/api/login", struct { Username string `json:"username"` Password string `json:"password"` }{ Username: "admin", Password: "correct", }) if resp.Code != http.StatusOK { t.Fatalf("got code %d, want %d", resp.Code, http.StatusOK) } cookieHeader := resp.Header().Get("Set-Cookie") if !strings.Contains(cookieHeader, "nadir_session_id=") { t.Fatalf("Set-Cookie header missing nadir_session_id: %q", cookieHeader) } parts := strings.SplitSeq(cookieHeader, ";") for part := range parts { part = strings.TrimSpace(part) if after, ok := strings.CutPrefix(part, "nadir_session_id="); ok { sessionID = after break } } if sessionID == "" { t.Fatal("nadir_session_id cookie not found") } if _, ok := sessions.GetByToken(sessionID); !ok { t.Fatal("session not found in session store") } }) t.Run("logout invalidates session", func(t *testing.T) { resp := api.Post("/api/logout", "Cookie: nadir_session_id="+sessionID, struct{}{}) if resp.Code != http.StatusOK { t.Fatalf("logout failed: got code %d, want %d", resp.Code, http.StatusOK) } if _, ok := sessions.GetByToken(sessionID); ok { t.Fatal("session still valid after logout") } }) t.Run("throttling blocks after 3 failures", func(t *testing.T) { for range 3 { api.Post("/api/login", struct { Username string `json:"username"` Password string `json:"password"` }{ Username: "throttled-user", Password: "wrong", }) } resp := api.Post("/api/login", struct { Username string `json:"username"` Password string `json:"password"` }{ Username: "throttled-user", Password: "correct", }) if resp.Code != http.StatusTooManyRequests { t.Errorf("got code %d, want %d", resp.Code, http.StatusTooManyRequests) } }) t.Run("throttle reset allows login", func(t *testing.T) { throttle.reset("throttled-user|") resp := api.Post("/api/login", struct { Username string `json:"username"` Password string `json:"password"` }{ Username: "throttled-user", Password: "correct", }) if resp.Code != http.StatusOK { t.Errorf("got code %d, want %d", resp.Code, http.StatusOK) } }) }