165 lines
4.8 KiB
Go
165 lines
4.8 KiB
Go
package rbac
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"nadir/internal/auditlog"
|
|
"nadir/internal/auth"
|
|
|
|
"github.com/danielgtaylor/huma/v2"
|
|
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
|
"github.com/danielgtaylor/huma/v2/humatest"
|
|
)
|
|
|
|
func TestRbacMiddleware(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer auditStore.Close()
|
|
|
|
sessions, err := auth.NewSessionStore(filepath.Join(tempDir, "sessions.db"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tokenStore, err := auth.NewTokenStore(filepath.Join(tempDir, "tokens.db"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tokenAuth := auth.NewTokenAuth(tokenStore)
|
|
|
|
r := New()
|
|
r.DefineRole(Role{
|
|
Name: "test-role",
|
|
ModuleGrants: map[string][]Permission{
|
|
"test-mod": {Read, Write},
|
|
},
|
|
})
|
|
r.AssignRole("alice", "test-role")
|
|
r.AssignRole("dash", "test-role")
|
|
|
|
mux := http.NewServeMux()
|
|
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
|
|
|
api.UseMiddleware(RbacMiddleware(api, sessions, tokenAuth, r, auditStore))
|
|
|
|
huma.Register(api, huma.Operation{
|
|
OperationID: "public-get",
|
|
Method: "GET",
|
|
Path: "/public",
|
|
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
|
return &struct{ Body string }{Body: "public"}, nil
|
|
})
|
|
|
|
huma.Register(api, huma.Operation{
|
|
OperationID: "gated-get",
|
|
Method: "GET",
|
|
Path: "/gated-read",
|
|
Metadata: map[string]any{"module": "test-mod", "permission": "read"},
|
|
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
|
return &struct{ Body string }{Body: "gated-read"}, nil
|
|
})
|
|
|
|
huma.Register(api, huma.Operation{
|
|
OperationID: "gated-post",
|
|
Method: "POST",
|
|
Path: "/gated-write",
|
|
Metadata: map[string]any{"module": "test-mod", "permission": "write"},
|
|
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
|
return &struct{ Body string }{Body: "gated-write"}, nil
|
|
})
|
|
|
|
t.Run("public route", func(t *testing.T) {
|
|
resp := api.Get("/public")
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("no auth returns 401", func(t *testing.T) {
|
|
resp := api.Get("/gated-read")
|
|
if resp.Code != http.StatusUnauthorized {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
t.Run("invalid cookie returns 401", func(t *testing.T) {
|
|
resp := api.Get("/gated-read", "Cookie: nadir_session_id=invalid")
|
|
if resp.Code != http.StatusUnauthorized {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
var aliceToken string
|
|
t.Run("valid session returns 200", func(t *testing.T) {
|
|
var err error
|
|
aliceToken, err = sessions.Create("alice")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp := api.Get("/gated-read", "Cookie: nadir_session_id="+aliceToken)
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("csrf mismatched origin returns 403", func(t *testing.T) {
|
|
resp := api.Post("/gated-write", "Cookie: nadir_session_id="+aliceToken, "Origin: http://evil.com", "Host: example.com", struct{}{})
|
|
if resp.Code != http.StatusForbidden {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
t.Run("csrf matching origin returns 200", func(t *testing.T) {
|
|
resp := api.Post("/gated-write", "Cookie: nadir_session_id="+aliceToken, "Origin: http://example.com", "Host: example.com", struct{}{})
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("unauthorized user returns 403", func(t *testing.T) {
|
|
bobToken, err := sessions.Create("bob")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp := api.Get("/gated-read", "Cookie: nadir_session_id="+bobToken)
|
|
if resp.Code != http.StatusForbidden {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
t.Run("valid bearer token returns 200", func(t *testing.T) {
|
|
rawToken, err := tokenStore.Create("dash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp := api.Get("/gated-read", "Authorization: Bearer "+rawToken)
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("bogus bearer token returns 401", func(t *testing.T) {
|
|
resp := api.Get("/gated-read", "Authorization: Bearer nad_deadbeef")
|
|
if resp.Code != http.StatusUnauthorized {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
t.Run("unassigned bearer token returns 403", func(t *testing.T) {
|
|
rawUnassigned, err := tokenStore.Create("orphan")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp := api.Get("/gated-read", "Authorization: Bearer "+rawUnassigned)
|
|
if resp.Code != http.StatusForbidden {
|
|
t.Errorf("got status %d, want %d", resp.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
}
|