Files
nadir-agent/internal/rbac/middleware_test.go
T
2026-06-24 17:29:45 +02:00

165 lines
4.8 KiB
Go

package rbac
import (
"context"
"net/http"
"path/filepath"
"testing"
"nadir/internal/auditlog"
"nadir/internal/auth"
"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/adapters/humago"
"github.com/danielgtaylor/huma/v2/humatest"
)
// TestRbacMiddleware drives RbacMiddleware through a huma test API that exposes
// one operation without RBAC metadata ("/public") and two operations tagged
// with module/permission metadata ("/gated-read", "/gated-write"). It checks
// the HTTP status returned for unauthenticated requests, session-cookie
// requests, bearer-token requests and cross-origin POSTs.
func TestRbacMiddleware(t *testing.T) {
	tempDir := t.TempDir()

	auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db"))
	if err != nil {
		t.Fatal(err)
	}
	defer auditStore.Close()

	// NOTE(review): unlike auditStore, the session and token stores are never
	// closed in this test. If they expose a Close method it should probably be
	// deferred here as well — confirm against the auth package.
	sessions, err := auth.NewSessionStore(filepath.Join(tempDir, "sessions.db"))
	if err != nil {
		t.Fatal(err)
	}
	tokenStore, err := auth.NewTokenStore(filepath.Join(tempDir, "tokens.db"))
	if err != nil {
		t.Fatal(err)
	}
	tokenAuth := auth.NewTokenAuth(tokenStore)

	// Only "alice" and "dash" receive the role; "bob" and "orphan" used below
	// are deliberately left without any role assignment.
	r := New()
	r.DefineRole(Role{
		Name: "test-role",
		ModuleGrants: map[string][]Permission{
			"test-mod": {Read, Write},
		},
	})
	r.AssignRole("alice", "test-role")
	r.AssignRole("dash", "test-role")

	mux := http.NewServeMux()
	api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
	api.UseMiddleware(RbacMiddleware(api, sessions, tokenAuth, r, auditStore))

	// Operation registered without module/permission metadata.
	huma.Register(api, huma.Operation{
		OperationID: "public-get",
		Method:      http.MethodGet,
		Path:        "/public",
	}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
		return &struct{ Body string }{Body: "public"}, nil
	})
	// Operation requiring the "read" permission on "test-mod".
	huma.Register(api, huma.Operation{
		OperationID: "gated-get",
		Method:      http.MethodGet,
		Path:        "/gated-read",
		Metadata:    map[string]any{"module": "test-mod", "permission": "read"},
	}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
		return &struct{ Body string }{Body: "gated-read"}, nil
	})
	// Operation requiring the "write" permission on "test-mod".
	huma.Register(api, huma.Operation{
		OperationID: "gated-post",
		Method:      http.MethodPost,
		Path:        "/gated-write",
		Metadata:    map[string]any{"module": "test-mod", "permission": "write"},
	}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
		return &struct{ Body string }{Body: "gated-write"}, nil
	})

	// Create alice's session during setup rather than inside a subtest, so that
	// every subtest relying on it also works when selected on its own with -run.
	aliceToken, err := sessions.Create("alice")
	if err != nil {
		t.Fatal(err)
	}
	aliceCookie := "Cookie: nadir_session_id=" + aliceToken

	// checkStatus reports a mismatch between the received and expected status.
	checkStatus := func(t *testing.T, got, want int) {
		t.Helper()
		if got != want {
			t.Errorf("got status %d, want %d", got, want)
		}
	}

	t.Run("public route", func(t *testing.T) {
		resp := api.Get("/public")
		checkStatus(t, resp.Code, http.StatusOK)
	})

	t.Run("no auth returns 401", func(t *testing.T) {
		resp := api.Get("/gated-read")
		checkStatus(t, resp.Code, http.StatusUnauthorized)
	})

	t.Run("invalid cookie returns 401", func(t *testing.T) {
		resp := api.Get("/gated-read", "Cookie: nadir_session_id=invalid")
		checkStatus(t, resp.Code, http.StatusUnauthorized)
	})

	t.Run("valid session returns 200", func(t *testing.T) {
		resp := api.Get("/gated-read", aliceCookie)
		checkStatus(t, resp.Code, http.StatusOK)
	})

	t.Run("csrf mismatched origin returns 403", func(t *testing.T) {
		// Origin names a different host than the Host header.
		resp := api.Post("/gated-write", aliceCookie, "Origin: http://evil.com", "Host: example.com", struct{}{})
		checkStatus(t, resp.Code, http.StatusForbidden)
	})

	t.Run("csrf matching origin returns 200", func(t *testing.T) {
		resp := api.Post("/gated-write", aliceCookie, "Origin: http://example.com", "Host: example.com", struct{}{})
		checkStatus(t, resp.Code, http.StatusOK)
	})

	t.Run("unauthorized user returns 403", func(t *testing.T) {
		// bob has a valid session but no role assignment.
		bobToken, err := sessions.Create("bob")
		if err != nil {
			t.Fatal(err)
		}
		resp := api.Get("/gated-read", "Cookie: nadir_session_id="+bobToken)
		checkStatus(t, resp.Code, http.StatusForbidden)
	})

	t.Run("valid bearer token returns 200", func(t *testing.T) {
		rawToken, err := tokenStore.Create("dash")
		if err != nil {
			t.Fatal(err)
		}
		resp := api.Get("/gated-read", "Authorization: Bearer "+rawToken)
		checkStatus(t, resp.Code, http.StatusOK)
	})

	t.Run("bogus bearer token returns 401", func(t *testing.T) {
		resp := api.Get("/gated-read", "Authorization: Bearer nad_deadbeef")
		checkStatus(t, resp.Code, http.StatusUnauthorized)
	})

	t.Run("unassigned bearer token returns 403", func(t *testing.T) {
		// The token is real, but "orphan" was never assigned a role.
		rawUnassigned, err := tokenStore.Create("orphan")
		if err != nil {
			t.Fatal(err)
		}
		resp := api.Get("/gated-read", "Authorization: Bearer "+rawUnassigned)
		checkStatus(t, resp.Code, http.StatusForbidden)
	})
}